Skip to main content

rmcp_server_kit/
transport.rs

1use std::{
2    future::Future,
3    net::SocketAddr,
4    path::{Path, PathBuf},
5    pin::Pin,
6    sync::Arc,
7    time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{body::Body, extract::Request, middleware::Next, response::IntoResponse};
12use rmcp::{
13    ServerHandler,
14    transport::streamable_http_server::{
15        StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
16    },
17};
18use rustls::RootCertStore;
19use tokio::net::TcpListener;
20use tokio_util::sync::CancellationToken;
21
22use crate::{
23    auth::{
24        AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
25        build_rate_limiter, extract_mtls_identity,
26    },
27    error::McpxError,
28    mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
29    rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
30};
31
32/// Map an internal `anyhow::Error` chain into a public [`McpxError::Startup`]
33/// at the public API boundary, flattening the chain via the alternate
34/// formatter so callers see the full causal path.
35#[allow(
36    clippy::needless_pass_by_value,
37    reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
38)]
39fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
40    McpxError::Startup(format!("{e:#}"))
41}
42
43/// Map a `std::io::Error` produced during server startup into a public
44/// [`McpxError::Startup`]. We deliberately do not use the [`McpxError::Io`]
45/// `From` impl here because startup-phase IO errors (bind, listener) are
46/// semantically distinct from request-time IO errors and should surface
47/// the originating operation in the message.
48#[allow(
49    clippy::needless_pass_by_value,
50    reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
51)]
52fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
53    McpxError::Startup(format!("{op}: {e}"))
54}
55
/// Async readiness check callback for the `/readyz` endpoint.
///
/// The alias is a shared (`Arc`), `Send + Sync` closure that takes no
/// arguments and returns a pinned, boxed, `Send` future resolving to a
/// [`serde_json::Value`].
///
/// Returns a JSON object with at least a `"ready"` boolean.
/// When `ready` is false, the endpoint returns HTTP 503.
pub type ReadinessCheck =
    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
62
/// Per-header overrides for the OWASP security headers emitted by the
/// global response middleware.
///
/// Each field follows a three-state semantic:
///
/// | Value         | Behaviour                                                |
/// |---------------|----------------------------------------------------------|
/// | `None`        | Use the built-in default (current behaviour).            |
/// | `Some("")`    | **Omit** the header entirely from responses.             |
/// | `Some(value)` | Emit `header: value`. Validated at config-load time.     |
///
/// All non-empty values are validated via
/// [`axum::http::HeaderValue::from_str`] inside
/// [`McpServerConfig::validate`]; invalid values fail fast before the
/// server starts accepting traffic.
///
/// `Strict-Transport-Security` has an additional rule: the substring
/// `preload` (case-insensitive) is rejected. Operators who want to
/// commit to the HSTS preload list must do so via a future explicit
/// builder method, not by smuggling it through this knob.
///
/// The struct derives [`Default`], which leaves every field at `None`
/// (all built-in defaults). Because it is `#[non_exhaustive]`, code
/// outside this crate cannot use a struct literal; start from
/// `SecurityHeadersConfig::default()` and assign the fields to override.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct SecurityHeadersConfig {
    /// Override for `X-Content-Type-Options`. Default: `nosniff`.
    pub x_content_type_options: Option<String>,
    /// Override for `X-Frame-Options`. Default: `deny`.
    pub x_frame_options: Option<String>,
    /// Override for `Cache-Control`. Default: `no-store, max-age=0`.
    pub cache_control: Option<String>,
    /// Override for `Referrer-Policy`. Default: `no-referrer`.
    pub referrer_policy: Option<String>,
    /// Override for `Cross-Origin-Opener-Policy`. Default: `same-origin`.
    pub cross_origin_opener_policy: Option<String>,
    /// Override for `Cross-Origin-Resource-Policy`. Default: `same-origin`.
    pub cross_origin_resource_policy: Option<String>,
    /// Override for `Cross-Origin-Embedder-Policy`. Default: `require-corp`.
    pub cross_origin_embedder_policy: Option<String>,
    /// Override for `Permissions-Policy`. Default:
    /// `accelerometer=(), camera=(), geolocation=(), microphone=()`.
    pub permissions_policy: Option<String>,
    /// Override for `X-Permitted-Cross-Domain-Policies`. Default: `none`.
    pub x_permitted_cross_domain_policies: Option<String>,
    /// Override for `Content-Security-Policy`. Default:
    /// `default-src 'none'; frame-ancestors 'none'`.
    pub content_security_policy: Option<String>,
    /// Override for `X-DNS-Prefetch-Control`. Default: `off`.
    pub x_dns_prefetch_control: Option<String>,
    /// Override for `Strict-Transport-Security`. Default (TLS only):
    /// `max-age=63072000; includeSubDomains`. Only emitted when TLS is
    /// active; the override is ignored on plaintext deployments. The
    /// substring `preload` (any case) is rejected by the validator.
    pub strict_transport_security: Option<String>,
}
116
/// Configuration for the MCP server.
///
/// Create a value with [`McpServerConfig::new`], customize it through the
/// chainable `with_*` / `enable_*` builder methods, then call
/// [`McpServerConfig::validate`] to obtain the
/// [`Validated<McpServerConfig>`] wrapper.
///
/// Every field is still `pub` but carries `#[deprecated]`: direct field
/// access is slated to become `pub(crate)` in a future major release, so
/// new code should go through the builder methods named in each
/// deprecation note.
#[allow(
    missing_debug_implementations,
    reason = "contains callback/trait objects that don't impl Debug"
)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "server configuration naturally has many boolean feature flags"
)]
#[non_exhaustive]
pub struct McpServerConfig {
    /// Socket address the MCP HTTP server binds to.
    /// Validation requires it to parse as a [`SocketAddr`]
    /// (e.g. `127.0.0.1:8080`).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in a future major release"
    )]
    pub bind_addr: String,
    /// Server name advertised via MCP `initialize`.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub name: String,
    /// Server version advertised via MCP `initialize`.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub version: String,
    /// Path to the TLS certificate (PEM). Required for TLS/mTLS.
    /// Must be set together with `tls_key_path`; validation rejects a
    /// certificate without a key. Default: `None`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_cert_path: Option<PathBuf>,
    /// Path to the TLS private key (PEM). Required for TLS/mTLS.
    /// Must be set together with `tls_cert_path`; validation rejects a
    /// key without a certificate. Default: `None`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_key_path: Option<PathBuf>,
    /// Optional authentication config. When `Some` and `enabled`, auth
    /// is enforced on `/mcp`. `/healthz` is always open.
    /// Default: `None`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in a future major release"
    )]
    pub auth: Option<AuthConfig>,
    /// Optional RBAC policy. When present and enabled, tool calls are
    /// checked against the policy after authentication.
    /// Default: `None`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in a future major release"
    )]
    pub rbac: Option<Arc<RbacPolicy>>,
    /// Allowed Origin values for DNS rebinding protection (MCP spec MUST).
    /// When empty and `public_url` is set, the origin is auto-derived from
    /// the public URL. When both are empty, only requests with no Origin
    /// header are accepted.
    /// Example entries: `"http://localhost:3000"`, `"https://myapp.example.com"`.
    /// Validation requires every entry to start with `http://` or
    /// `https://`. Default: empty.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in a future major release"
    )]
    pub allowed_origins: Vec<String>,
    /// Maximum tool invocations per source IP per minute.
    /// When set, enforced on every `tools/call` request.
    /// Default: `None`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in a future major release"
    )]
    pub tool_rate_limit: Option<u32>,
    /// Optional readiness probe for `/readyz`.
    /// When `None` (the default), `/readyz` mirrors `/healthz` (always OK).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in a future major release"
    )]
    pub readiness_check: Option<ReadinessCheck>,
    /// Maximum request body size in bytes. Default: 1 MiB.
    /// Protects against oversized payloads causing OOM.
    /// Validation rejects a value of zero.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_request_body: usize,
    /// Request processing timeout. Default: 120s.
    /// Requests exceeding this duration receive 408 Request Timeout.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub request_timeout: Duration,
    /// Graceful shutdown timeout. Default: 30s.
    /// After the shutdown signal, in-flight requests have this long to finish.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub shutdown_timeout: Duration,
    /// Idle timeout for MCP sessions. Sessions with no activity for this
    /// duration are closed automatically. Default: 20 minutes.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub session_idle_timeout: Duration,
    /// Interval for SSE keep-alive pings. Prevents proxies and load
    /// balancers from killing idle connections. Default: 15 seconds.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in a future major release"
    )]
    pub sse_keep_alive: Duration,
    /// Callback invoked once the server is built, delivering a
    /// [`ReloadHandle`] for hot-reloading auth keys and RBAC policy
    /// at runtime (e.g. on SIGHUP). Only useful when auth/RBAC is enabled.
    /// Default: `None`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in a future major release"
    )]
    pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    /// Additional application-specific routes merged into the top-level
    /// router.  These routes **bypass** the MCP auth and RBAC middleware,
    /// so the application is responsible for its own auth on them.
    /// Default: `None`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in a future major release"
    )]
    pub extra_router: Option<axum::Router>,
    /// Externally reachable base URL (e.g. `https://mcp.example.com`).
    /// When set, OAuth metadata endpoints advertise this URL instead of
    /// the listen address. Required when binding `0.0.0.0` behind a
    /// reverse proxy or inside a container.
    /// Validation requires the value to start with `http://` or
    /// `https://`. Default: `None`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in a future major release"
    )]
    pub public_url: Option<String>,
    /// Log inbound HTTP request headers at DEBUG level.
    /// Sensitive values remain redacted. Default: `false`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in a future major release"
    )]
    pub log_request_headers: bool,
    /// Enable gzip/br response compression on MCP responses.
    /// Defaults to `false` to preserve existing behaviour.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_enabled: bool,
    /// Minimum response body size (in bytes) before compression kicks in.
    /// Only used when `compression_enabled` is true. Default: 1024.
    /// Stored as `u16`, so the threshold cannot exceed 65535 bytes.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_min_size: u16,
    /// Global cap on in-flight HTTP requests across the whole server.
    /// When `Some`, requests over the cap receive 503 Service Unavailable
    /// via `tower::load_shed`. Default: `None` (unlimited).
    /// Validation rejects `Some(0)`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_concurrent_requests: Option<usize>,
    /// Enable `/admin/*` diagnostic endpoints. Requires `auth` to be
    /// configured and `enabled`. Default: `false`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_enabled: bool,
    /// RBAC role required to access admin endpoints. Default: `"admin"`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_role: String,
    /// Enable Prometheus metrics endpoint on a separate listener.
    /// Requires the `metrics` crate feature. Default: `false`.
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_enabled: bool,
    /// Bind address for the Prometheus metrics listener. Default: `127.0.0.1:9090`.
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_bind: String,
    /// Per-header overrides for the OWASP security headers emitted by
    /// the global response middleware. See [`SecurityHeadersConfig`]
    /// for the three-state semantic and validation rules.
    /// Default: `SecurityHeadersConfig::default()` (no overrides).
    #[deprecated(
        since = "1.5.0",
        note = "use McpServerConfig::with_security_headers(); direct field access will become pub(crate) in a future major release"
    )]
    pub security_headers: SecurityHeadersConfig,
}
322
/// Proof-carrying wrapper: a `Validated<T>` holds a `T` that has passed
/// its validation contract.
///
/// For the server configuration, the only route to a
/// `Validated<McpServerConfig>` is [`McpServerConfig::validate`]; both
/// [`serve`] and [`serve_with_listener`] demand the wrapper, so the
/// requirement is enforced by the type system. The wrapped field is
/// private, which stops downstream code from constructing the wrapper
/// by hand and skipping validation.
///
/// Read-only access goes through `Deref` (`&Validated<T>` behaves like
/// `&T`) or [`Validated::as_inner`]. There is no mutable access: call
/// [`Validated::into_inner`] to get the raw value back, change it, and
/// validate again.
///
/// # Example
///
/// ```no_run
/// use rmcp_server_kit::transport::{McpServerConfig, Validated, serve};
/// use rmcp::handler::server::ServerHandler;
/// use rmcp::model::{ServerCapabilities, ServerInfo};
///
/// #[derive(Clone)]
/// struct H;
/// impl ServerHandler for H {
///     fn get_info(&self) -> ServerInfo {
///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
///     }
/// }
///
/// # async fn example() -> rmcp_server_kit::Result<()> {
/// let config: Validated<McpServerConfig> =
///     McpServerConfig::new("127.0.0.1:8080", "my-server", "0.1.0").validate()?;
/// serve(config, || H).await
/// # }
/// ```
///
/// Forgetting `.validate()?` is a compile error:
///
/// ```compile_fail
/// use rmcp_server_kit::transport::{McpServerConfig, serve};
/// use rmcp::handler::server::ServerHandler;
/// use rmcp::model::{ServerCapabilities, ServerInfo};
///
/// #[derive(Clone)]
/// struct H;
/// impl ServerHandler for H {
///     fn get_info(&self) -> ServerInfo {
///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
///     }
/// }
///
/// # async fn example() -> rmcp_server_kit::Result<()> {
/// let config = McpServerConfig::new("127.0.0.1:8080", "my-server", "0.1.0");
/// // Missing `.validate()?` -> mismatched types: expected
/// // `Validated<McpServerConfig>`, found `McpServerConfig`.
/// serve(config, || H).await
/// # }
/// ```
#[allow(
    missing_debug_implementations,
    reason = "wraps T which may not implement Debug; manual impl below avoids leaking inner contents into logs"
)]
pub struct Validated<T>(T);

/// Opaque `Debug`: always prints `Validated { .. }` and never formats
/// the wrapped value, so no `T: Debug` bound is needed and the inner
/// contents stay out of logs.
impl<T> std::fmt::Debug for Validated<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same text a field-less `debug_struct(..).finish_non_exhaustive()`
        // produces, in both `{:?}` and `{:#?}` mode.
        f.write_str("Validated { .. }")
    }
}

impl<T> Validated<T> {
    /// Shared reference to the wrapped, already-validated value.
    #[must_use]
    pub fn as_inner(&self) -> &T {
        let Self(inner) = self;
        inner
    }

    /// Unwrap the raw value, giving up the validation proof.
    ///
    /// The returned value must be validated again before it can be
    /// handed to [`serve`] or [`serve_with_listener`].
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}

/// Read-only deref to the wrapped value; equivalent to
/// [`Validated::as_inner`].
impl<T> std::ops::Deref for Validated<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.as_inner()
    }
}
416
417#[allow(
418    deprecated,
419    reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
420)]
421impl McpServerConfig {
422    /// Create a new server configuration with the given bind address,
423    /// server name, and version. All other fields use safe defaults.
424    ///
425    /// Use the chainable `with_*` / `enable_*` builder methods to
426    /// customize. Call [`McpServerConfig::validate`] to obtain a
427    /// [`Validated<McpServerConfig>`] proof token, which is required by
428    /// [`serve`] and [`serve_with_listener`].
429    #[must_use]
430    pub fn new(
431        bind_addr: impl Into<String>,
432        name: impl Into<String>,
433        version: impl Into<String>,
434    ) -> Self {
435        Self {
436            bind_addr: bind_addr.into(),
437            name: name.into(),
438            version: version.into(),
439            tls_cert_path: None,
440            tls_key_path: None,
441            auth: None,
442            rbac: None,
443            allowed_origins: Vec::new(),
444            tool_rate_limit: None,
445            readiness_check: None,
446            max_request_body: 1024 * 1024,
447            request_timeout: Duration::from_mins(2),
448            shutdown_timeout: Duration::from_secs(30),
449            session_idle_timeout: Duration::from_mins(20),
450            sse_keep_alive: Duration::from_secs(15),
451            on_reload_ready: None,
452            extra_router: None,
453            public_url: None,
454            log_request_headers: false,
455            compression_enabled: false,
456            compression_min_size: 1024,
457            max_concurrent_requests: None,
458            admin_enabled: false,
459            admin_role: "admin".to_owned(),
460            #[cfg(feature = "metrics")]
461            metrics_enabled: false,
462            #[cfg(feature = "metrics")]
463            metrics_bind: "127.0.0.1:9090".into(),
464            security_headers: SecurityHeadersConfig::default(),
465        }
466    }
467
468    // ---------------------------------------------------------------
469    // Builder methods (fluent, consume + return self).
470    //
471    // Each method is `#[must_use]` because dropping the returned
472    // `McpServerConfig` discards the configuration change.
473    // ---------------------------------------------------------------
474
475    /// Attach an authentication configuration. Required for
476    /// [`enable_admin`](Self::enable_admin) and any non-public deployment.
477    #[must_use]
478    pub fn with_auth(mut self, auth: AuthConfig) -> Self {
479        self.auth = Some(auth);
480        self
481    }
482
483    /// Override one or more of the OWASP security headers emitted on
484    /// every response. See [`SecurityHeadersConfig`] for the three-state
485    /// semantic (`None` = default, `Some("")` = omit, `Some(v)` =
486    /// override). Values are validated by [`Self::validate`].
487    #[must_use]
488    pub fn with_security_headers(mut self, headers: SecurityHeadersConfig) -> Self {
489        self.security_headers = headers;
490        self
491    }
492
493    /// Override the bind address (e.g. `127.0.0.1:8080`). Useful when the
494    /// final port is only known after pre-binding an ephemeral listener
495    /// (tests, dynamic-port deployments).
496    #[must_use]
497    pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
498        self.bind_addr = addr.into();
499        self
500    }
501
502    /// Attach an RBAC policy. Tool calls are checked against the policy
503    /// after authentication.
504    #[must_use]
505    pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
506        self.rbac = Some(rbac);
507        self
508    }
509
510    /// Configure TLS by providing the certificate and private key paths
511    /// (PEM). Both must be readable at startup. Without this call, the
512    /// server runs plain HTTP.
513    #[must_use]
514    pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
515        self.tls_cert_path = Some(cert_path.into());
516        self.tls_key_path = Some(key_path.into());
517        self
518    }
519
520    /// Set the externally reachable base URL (e.g. `https://mcp.example.com`).
521    /// Required when binding `0.0.0.0` behind a reverse proxy or inside
522    /// a container so OAuth metadata and auto-derived origins resolve correctly.
523    #[must_use]
524    pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
525        self.public_url = Some(url.into());
526        self
527    }
528
529    /// Replace the allowed Origin allow-list (DNS-rebinding protection).
530    /// When empty and [`with_public_url`](Self::with_public_url) is set,
531    /// the origin is auto-derived.
532    #[must_use]
533    pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
534    where
535        I: IntoIterator<Item = S>,
536        S: Into<String>,
537    {
538        self.allowed_origins = origins.into_iter().map(Into::into).collect();
539        self
540    }
541
542    /// Merge an additional axum router at the top level. Routes added
543    /// here **bypass** rmcp-server-kit auth and RBAC; the application is responsible
544    /// for its own protection.
545    #[must_use]
546    pub fn with_extra_router(mut self, router: axum::Router) -> Self {
547        self.extra_router = Some(router);
548        self
549    }
550
551    /// Install an async readiness probe for `/readyz`. Without this call,
552    /// `/readyz` mirrors `/healthz` (always 200 OK).
553    #[must_use]
554    pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
555        self.readiness_check = Some(check);
556        self
557    }
558
559    /// Override the maximum request body (bytes). Must be `> 0`.
560    /// Default: 1 MiB.
561    #[must_use]
562    pub fn with_max_request_body(mut self, bytes: usize) -> Self {
563        self.max_request_body = bytes;
564        self
565    }
566
567    /// Override the per-request processing timeout. Default: 2 minutes.
568    #[must_use]
569    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
570        self.request_timeout = timeout;
571        self
572    }
573
574    /// Override the graceful shutdown grace period. Default: 30 seconds.
575    #[must_use]
576    pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
577        self.shutdown_timeout = timeout;
578        self
579    }
580
581    /// Override the MCP session idle timeout. Default: 20 minutes.
582    #[must_use]
583    pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
584        self.session_idle_timeout = timeout;
585        self
586    }
587
588    /// Override the SSE keep-alive interval. Default: 15 seconds.
589    #[must_use]
590    pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
591        self.sse_keep_alive = interval;
592        self
593    }
594
595    /// Cap the global number of in-flight HTTP requests via
596    /// `tower::load_shed`. Excess requests receive 503 Service Unavailable.
597    /// Default: unlimited.
598    #[must_use]
599    pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
600        self.max_concurrent_requests = Some(limit);
601        self
602    }
603
604    /// Cap tool invocations per source IP per minute. Enforced on every
605    /// `tools/call` request.
606    #[must_use]
607    pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
608        self.tool_rate_limit = Some(per_minute);
609        self
610    }
611
612    /// Register a callback that receives the [`ReloadHandle`] after the
613    /// server is built. Use it to wire SIGHUP-style hot reloads of API
614    /// keys and RBAC policy.
615    #[must_use]
616    pub fn with_reload_callback<F>(mut self, callback: F) -> Self
617    where
618        F: FnOnce(ReloadHandle) + Send + 'static,
619    {
620        self.on_reload_ready = Some(Box::new(callback));
621        self
622    }
623
624    /// Enable gzip/brotli response compression on MCP responses.
625    /// `min_size` is the smallest body size (bytes) eligible for
626    /// compression. Default min size: 1024.
627    #[must_use]
628    pub fn enable_compression(mut self, min_size: u16) -> Self {
629        self.compression_enabled = true;
630        self.compression_min_size = min_size;
631        self
632    }
633
634    /// Enable `/admin/*` diagnostic endpoints. Requires
635    /// [`with_auth`](Self::with_auth) to be set and enabled; otherwise
636    /// [`validate`](Self::validate) returns an error. `role` is the RBAC
637    /// role gate (default: `"admin"`).
638    #[must_use]
639    pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
640        self.admin_enabled = true;
641        self.admin_role = role.into();
642        self
643    }
644
645    /// Log inbound HTTP request headers at DEBUG level. Sensitive
646    /// values remain redacted by the logging layer.
647    #[must_use]
648    pub fn enable_request_header_logging(mut self) -> Self {
649        self.log_request_headers = true;
650        self
651    }
652
653    /// Enable the Prometheus metrics listener on `bind` (e.g.
654    /// `127.0.0.1:9090`). Requires the `metrics` crate feature.
655    #[cfg(feature = "metrics")]
656    #[must_use]
657    pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
658        self.metrics_enabled = true;
659        self.metrics_bind = bind.into();
660        self
661    }
662
663    /// Validate the configuration and consume `self`, returning a
664    /// [`Validated<McpServerConfig>`] proof token required by [`serve`]
665    /// and [`serve_with_listener`]. This is the only way to construct
666    /// `Validated<McpServerConfig>`, so the type system guarantees
667    /// validation has run before the server starts.
668    ///
669    /// Checks:
670    ///
671    /// 1. `admin_enabled` requires `auth` to be configured and enabled.
672    /// 2. `tls_cert_path` and `tls_key_path` must both be set or both
673    ///    be unset.
674    /// 3. `bind_addr` must parse as a [`SocketAddr`].
675    /// 4. `public_url`, when set, must start with `http://` or `https://`.
676    /// 5. Each entry in `allowed_origins` must start with `http://` or
677    ///    `https://`.
678    /// 6. `max_request_body` must be greater than zero.
679    /// 7. When the `oauth` feature is enabled and an [`OAuthConfig`] is
680    ///    present, all OAuth URL fields (`jwks_uri`, `proxy.authorize_url`,
681    ///    `proxy.token_url`, `proxy.introspection_url`,
682    ///    `proxy.revocation_url`, `token_exchange.token_url`) must parse
683    ///    and use the `https` scheme. Set
684    ///    [`OAuthConfig::allow_http_oauth_urls`] to permit `http://`
685    ///    targets (strongly discouraged in production - see the field-level
686    ///    docs for the threat model).
687    ///
688    /// [`OAuthConfig`]: crate::oauth::OAuthConfig
689    /// [`OAuthConfig::allow_http_oauth_urls`]: crate::oauth::OAuthConfig::allow_http_oauth_urls
690    ///
691    /// # Errors
692    ///
693    /// Returns [`McpxError::Config`] with a human-readable message on
694    /// the first validation failure.
695    pub fn validate(self) -> Result<Validated<Self>, McpxError> {
696        self.check()?;
697        Ok(Validated(self))
698    }
699
700    /// Run the validation checks without consuming `self`. Used by
701    /// internal call sites (e.g. tests) that need to inspect a config
702    /// without taking ownership.
703    fn check(&self) -> Result<(), McpxError> {
704        // 1. admin <-> auth dependency. Mirrors the runtime check in
705        //    `build_app_router`: admin endpoints require an auth state,
706        //    which is built only when `auth` is `Some` *and* `enabled`.
707        if self.admin_enabled {
708            let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
709            if !auth_enabled {
710                return Err(McpxError::Config(
711                    "admin_enabled=true requires auth to be configured and enabled".into(),
712                ));
713            }
714        }
715
716        // 2. TLS cert / key must be paired
717        match (&self.tls_cert_path, &self.tls_key_path) {
718            (Some(_), None) => {
719                return Err(McpxError::Config(
720                    "tls_cert_path is set but tls_key_path is missing".into(),
721                ));
722            }
723            (None, Some(_)) => {
724                return Err(McpxError::Config(
725                    "tls_key_path is set but tls_cert_path is missing".into(),
726                ));
727            }
728            _ => {}
729        }
730
731        // 3. bind_addr parses
732        if self.bind_addr.parse::<SocketAddr>().is_err() {
733            return Err(McpxError::Config(format!(
734                "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
735                self.bind_addr
736            )));
737        }
738
739        // 4. public_url scheme
740        if let Some(ref url) = self.public_url
741            && !(url.starts_with("http://") || url.starts_with("https://"))
742        {
743            return Err(McpxError::Config(format!(
744                "public_url {url:?} must start with http:// or https://"
745            )));
746        }
747
748        // 5. allowed_origins scheme
749        for origin in &self.allowed_origins {
750            if !(origin.starts_with("http://") || origin.starts_with("https://")) {
751                return Err(McpxError::Config(format!(
752                    "allowed_origins entry {origin:?} must start with http:// or https://"
753                )));
754            }
755        }
756
757        // 6. max_request_body > 0
758        if self.max_request_body == 0 {
759            return Err(McpxError::Config(
760                "max_request_body must be greater than zero".into(),
761            ));
762        }
763
764        // 7. OAuth URL fields enforce HTTPS (unless `allow_http_oauth_urls`)
765        #[cfg(feature = "oauth")]
766        if let Some(auth_cfg) = &self.auth
767            && let Some(oauth_cfg) = &auth_cfg.oauth
768        {
769            oauth_cfg.validate()?;
770        }
771
772        // 8. Security-header overrides parse as valid HTTP header values,
773        //    and HSTS does not smuggle in a `preload` directive.
774        validate_security_headers(&self.security_headers)?;
775
776        // 9. max_concurrent_requests must be > 0 when set. Zero would
777        //    deadlock the global concurrency limiter and reject every
778        //    request. Mirrors the TOML-side check in `src/config.rs`.
779        if let Some(0) = self.max_concurrent_requests {
780            return Err(McpxError::Config(
781                "max_concurrent_requests must be greater than zero when set".into(),
782            ));
783        }
784
785        // 10. Auth rate-limit `max_tracked_keys` must be > 0. A zero cap
786        //     would force `BoundedKeyedLimiter` to evict on every insert
787        //     and effectively disable rate limiting.
788        if let Some(auth_cfg) = &self.auth
789            && let Some(rl) = &auth_cfg.rate_limit
790            && rl.max_tracked_keys == 0
791        {
792            return Err(McpxError::Config(
793                "auth.rate_limit.max_tracked_keys must be greater than zero".into(),
794            ));
795        }
796
797        Ok(())
798    }
799}
800
/// Handle for hot-reloading server configuration without restart.
///
/// Obtained via [`McpServerConfig::on_reload_ready`].
/// All swap operations are lock-free and wait-free -- in-flight requests
/// finish with the old values while new requests see the update immediately.
///
/// Each field is optional: a reload method whose backing state is `None`
/// is a no-op (or, for [`ReloadHandle::refresh_crls`], an error).
#[allow(
    missing_debug_implementations,
    reason = "contains Arc<AuthState> with non-Debug fields"
)]
pub struct ReloadHandle {
    /// Shared auth state whose API key list is replaced by
    /// [`ReloadHandle::reload_auth_keys`]; `None` leaves that call a no-op.
    auth: Option<Arc<AuthState>>,
    /// Swappable RBAC policy replaced by [`ReloadHandle::reload_rbac`];
    /// `None` leaves that call a no-op.
    rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
    /// CRL set refreshed by [`ReloadHandle::refresh_crls`]; `None` makes
    /// that call return a configuration error.
    crl_set: Option<Arc<CrlSet>>,
}
815
816impl ReloadHandle {
817    /// Atomically replace the API key list used by the auth middleware.
818    pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
819        if let Some(ref auth) = self.auth {
820            auth.reload_keys(keys);
821        }
822    }
823
824    /// Atomically replace the RBAC policy used by the RBAC middleware.
825    pub fn reload_rbac(&self, policy: RbacPolicy) {
826        if let Some(ref rbac) = self.rbac {
827            rbac.store(Arc::new(policy));
828            tracing::info!("RBAC policy reloaded");
829        }
830    }
831
832    /// Force an immediate refresh of all cached mTLS CRLs.
833    ///
834    /// # Errors
835    ///
836    /// Returns an error if CRL refresh is unavailable or verifier rebuild fails.
837    pub async fn refresh_crls(&self) -> Result<(), McpxError> {
838        let Some(ref crl_set) = self.crl_set else {
839            return Err(McpxError::Config(
840                "CRL refresh requested but mTLS CRL support is not configured".into(),
841            ));
842        };
843
844        crl_set.force_refresh().await
845    }
846}
847
848/// Generic MCP HTTP server.
849///
850/// Wraps an axum server with `/healthz` and `/mcp` endpoints.
851/// When `tls_cert_path` and `tls_key_path` are both set, the server binds
852/// with TLS (rustls). Optionally supports mTLS client certificate auth.
853///
854/// # Errors
855///
856/// Returns an error if the TCP listener cannot bind, TLS config is invalid,
857/// or the server fails.
858// TODO(refactor): cognitive complexity reduced from 111/25 to 83/25 by
859// extracting `run_server` (serve-loop tail) and `install_oauth_proxy_routes`.
860// Remaining flow is a linear router builder: middleware layering, feature-
861// gated auth/RBAC wiring, and PRM/metrics installation. Further extraction
862// would require threading many `&mut Router` helpers and hurt readability
863// of the layer order (which is security-relevant and must stay visible).
864#[allow(clippy::too_many_lines, clippy::cognitive_complexity)]
865/// Internal bundle of values produced by [`build_app_router`] and
866/// consumed by [`serve`] / [`serve_with_listener`] when driving the
867/// HTTP listener.
868struct AppRunParams {
869    /// TLS cert/key paths when TLS is configured.
870    tls_paths: Option<(PathBuf, PathBuf)>,
871    /// mTLS configuration when mutual-TLS auth is enabled.
872    mtls_config: Option<MtlsConfig>,
873    /// Graceful shutdown drain window.
874    shutdown_timeout: Duration,
875    /// Shared auth state used by hot-reload callbacks.
876    auth_state: Option<Arc<AuthState>>,
877    /// Hot-reloadable RBAC state used by reload callbacks.
878    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
879    /// Optional callback that receives the final [`ReloadHandle`].
880    on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
881    /// Server-internal cancellation token. Cancelled by [`run_server`]
882    /// once the shutdown trigger fires (so rmcp's child token also
883    /// fires, terminating in-flight MCP sessions).
884    ct: CancellationToken,
885    /// `"http"` or `"https"` -- used only for boot-time logging.
886    scheme: &'static str,
887    /// Server name -- used only for boot-time logging.
888    name: String,
889}
890
891/// Build the full application axum [`axum::Router`] (MCP route +
892/// middleware stack + admin + OAuth + health endpoints + security
893/// headers + CORS + compression + concurrency limit + origin check)
894/// and the [`AppRunParams`] needed to drive it.
895///
896/// This is the shared core of [`serve`] and [`serve_with_listener`].
897/// It performs *no* network I/O: callers are responsible for binding
898/// (or accepting a pre-bound) [`TcpListener`] and invoking
899/// [`run_server`].
900#[allow(
901    clippy::cognitive_complexity,
902    reason = "router assembly is intrinsically sequential; splitting harms readability"
903)]
904#[allow(
905    deprecated,
906    reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
907)]
908fn build_app_router<H, F>(
909    mut config: McpServerConfig,
910    handler_factory: F,
911) -> anyhow::Result<(axum::Router, AppRunParams)>
912where
913    H: ServerHandler + 'static,
914    F: Fn() -> H + Send + Sync + Clone + 'static,
915{
916    let ct = CancellationToken::new();
917
918    let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
919    tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");
920
921    let mcp_service = StreamableHttpService::new(
922        move || Ok(handler_factory()),
923        {
924            let mut mgr = LocalSessionManager::default();
925            mgr.session_config.keep_alive = Some(config.session_idle_timeout);
926            mgr.into()
927        },
928        StreamableHttpServerConfig::default()
929            .with_allowed_hosts(allowed_hosts)
930            .with_sse_keep_alive(Some(config.sse_keep_alive))
931            .with_cancellation_token(ct.child_token()),
932    );
933
934    // Build the MCP route, optionally wrapped with auth and RBAC middleware.
935    let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);
936
937    // Build auth state eagerly when auth is configured so we can wire both
938    // the auth middleware *and* the optional admin router against the same
939    // state. The middleware itself is installed further down in layer order.
940    let auth_state: Option<Arc<AuthState>> = match config.auth {
941        Some(ref auth_config) if auth_config.enabled => {
942            let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
943            let pre_auth_limiter = auth_config
944                .rate_limit
945                .as_ref()
946                .map(crate::auth::build_pre_auth_limiter);
947
948            #[cfg(feature = "oauth")]
949            let jwks_cache = auth_config
950                .oauth
951                .as_ref()
952                .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
953                .transpose()
954                .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;
955
956            Some(Arc::new(AuthState {
957                api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
958                rate_limiter,
959                pre_auth_limiter,
960                #[cfg(feature = "oauth")]
961                jwks_cache,
962                seen_identities: std::sync::Mutex::new(std::collections::HashSet::new()),
963                counters: crate::auth::AuthCounters::default(),
964            }))
965        }
966        _ => None,
967    };
968
969    // Build the RBAC policy swap early so the admin router and the later
970    // RBAC middleware layer share the same hot-reloadable state.
971    let rbac_swap = Arc::new(ArcSwap::new(
972        config
973            .rbac
974            .clone()
975            .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
976    ));
977
978    // Optional /admin/* diagnostic routes. Merged BEFORE the
979    // body-limit/timeout/RBAC/origin/auth layers so all of them apply.
980    if config.admin_enabled {
981        let Some(ref auth_state_ref) = auth_state else {
982            return Err(anyhow::anyhow!(
983                "admin_enabled=true requires auth to be configured and enabled"
984            ));
985        };
986        let admin_state = crate::admin::AdminState {
987            started_at: std::time::Instant::now(),
988            name: config.name.clone(),
989            version: config.version.clone(),
990            auth: Some(Arc::clone(auth_state_ref)),
991            rbac: Arc::clone(&rbac_swap),
992        };
993        let admin_cfg = crate::admin::AdminConfig {
994            role: config.admin_role.clone(),
995        };
996        mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
997        tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
998    }
999
1000    // ----- Middleware order (CRITICAL: read carefully) ------------------
1001    //
1002    // axum/tower applies layers **bottom-up** at runtime: the LAST layer
1003    // added is the OUTERMOST (runs first on a request). To achieve a
1004    // request-time flow of:
1005    //
1006    //   body-limit -> timeout -> auth -> rbac -> handler
1007    //
1008    // we add layers in the REVERSE order:
1009    //
1010    //   1. RBAC               (innermost, runs last before handler)
1011    //   2. auth               (parses identity, sets extension for RBAC)
1012    //   3. timeout            (bounds total request time)
1013    //   4. body-limit         (outermost on /mcp; caps payload before
1014    //                          anything else reads/buffers it)
1015    //
1016    // Origin validation is installed on the OUTER router (after the
1017    // /mcp router is merged in), so it also protects /healthz, /readyz,
1018    // /version, and any OAuth proxy endpoints.
1019    //
1020    // Rationale:
1021    // - Body-limit must be outermost on /mcp so RBAC (which reads the
1022    //   JSON-RPC body) cannot be DoS'd by a 100MB payload.
1023    // - Auth must run before RBAC because RBAC consumes
1024    //   `req.extensions().get::<AuthIdentity>()` to enforce per-role
1025    //   policy.
1026    // - Origin runs before auth so we reject cross-origin requests
1027    //   without spending Argon2 cycles on unauthenticated callers.
1028
1029    // [1] RBAC + tool rate-limit layer (innermost; closest to handler).
1030    // Always installed: even when RBAC is disabled, tool rate limiting may
1031    // be active (MCP spec: servers MUST rate limit tool invocations).
1032    {
1033        let tool_limiter: Option<Arc<ToolRateLimiter>> =
1034            config.tool_rate_limit.map(build_tool_rate_limiter);
1035
1036        if rbac_swap.load().is_enabled() {
1037            tracing::info!("RBAC enforcement enabled on /mcp");
1038        }
1039        if let Some(limit) = config.tool_rate_limit {
1040            tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
1041        }
1042
1043        let rbac_for_mw = Arc::clone(&rbac_swap);
1044        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1045            let p = rbac_for_mw.load_full();
1046            let tl = tool_limiter.clone();
1047            rbac_middleware(p, tl, req, next)
1048        }));
1049    }
1050
1051    // [2] Auth layer (runs before RBAC so AuthIdentity is in extensions).
1052    if let Some(ref auth_config) = config.auth
1053        && auth_config.enabled
1054    {
1055        let Some(ref state) = auth_state else {
1056            return Err(anyhow::anyhow!("auth state missing despite enabled config"));
1057        };
1058
1059        let methods: Vec<&str> = [
1060            auth_config.mtls.is_some().then_some("mTLS"),
1061            (!auth_config.api_keys.is_empty()).then_some("bearer"),
1062            #[cfg(feature = "oauth")]
1063            auth_config.oauth.is_some().then_some("oauth-jwt"),
1064        ]
1065        .into_iter()
1066        .flatten()
1067        .collect();
1068
1069        tracing::info!(
1070            methods = %methods.join(", "),
1071            api_keys = auth_config.api_keys.len(),
1072            "auth enabled on /mcp"
1073        );
1074
1075        let state_for_mw = Arc::clone(state);
1076        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1077            let s = Arc::clone(&state_for_mw);
1078            auth_middleware(s, req, next)
1079        }));
1080    }
1081
1082    // [3] Request timeout (returns 408 on expiry). Bounds total request
1083    // duration including auth + handler.
1084    mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
1085        axum::http::StatusCode::REQUEST_TIMEOUT,
1086        config.request_timeout,
1087    ));
1088
1089    // [4] Request body size limit (OUTERMOST on /mcp). Prevents OOM /
1090    // DoS from oversized payloads BEFORE any inner layer (auth, RBAC)
1091    // attempts to buffer or parse the body.
1092    mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
1093        config.max_request_body,
1094    ));
1095
1096    // Compute the effective allowed-origins list for the outer
1097    // origin-check layer (installed on the merged router below). When
1098    // `allowed_origins` is empty but `public_url` is set, auto-derive
1099    // the origin from the public URL so MCP clients (e.g. Claude Code)
1100    // that send `Origin: <server-url>` are accepted without explicit
1101    // config.
1102    let mut effective_origins = config.allowed_origins.clone();
1103    if effective_origins.is_empty()
1104        && let Some(ref url) = config.public_url
1105    {
1106        // Origin = scheme + "://" + host (+ ":" + port if non-default).
1107        // Strip any path/query from the public URL.
1108        if let Some(scheme_end) = url.find("://") {
1109            let after_scheme = &url[scheme_end + 3..];
1110            let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
1111            let origin = format!("{}{}", &url[..scheme_end + 3], &after_scheme[..host_end]);
1112            tracing::info!(
1113                %origin,
1114                "auto-derived allowed origin from public_url"
1115            );
1116            effective_origins.push(origin);
1117        }
1118    }
1119    let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
1120    let cors_origins = Arc::clone(&allowed_origins);
1121    let log_request_headers = config.log_request_headers;
1122
1123    let readyz_route = if let Some(check) = config.readiness_check.take() {
1124        axum::routing::get(move || readyz(Arc::clone(&check)))
1125    } else {
1126        axum::routing::get(healthz)
1127    };
1128
1129    #[allow(unused_mut)] // mut needed when oauth feature adds PRM route
1130    let mut router = axum::Router::new()
1131        .route("/healthz", axum::routing::get(healthz))
1132        .route("/readyz", readyz_route)
1133        .route(
1134            "/version",
1135            axum::routing::get({
1136                // Pre-serialize the version payload once at router-build
1137                // time. The handler then serves a cheap `Arc::clone` of the
1138                // immutable bytes per request, avoiding `serde_json::Value`
1139                // allocation + serialization on every `/version` hit.
1140                let payload_bytes: Arc<[u8]> =
1141                    serialize_version_payload(&config.name, &config.version);
1142                move || {
1143                    let p = Arc::clone(&payload_bytes);
1144                    async move {
1145                        (
1146                            [(axum::http::header::CONTENT_TYPE, "application/json")],
1147                            p.to_vec(),
1148                        )
1149                    }
1150                }
1151            }),
1152        )
1153        .merge(mcp_router);
1154
1155    // Merge application-specific routes (bypass MCP auth/RBAC middleware).
1156    if let Some(extra) = config.extra_router.take() {
1157        router = router.merge(extra);
1158    }
1159
1160    // RFC 9728: Protected Resource Metadata endpoint.
1161    // When OAuth is configured, serve full metadata with authorization_servers.
1162    // Otherwise, serve a minimal document with just the resource URL and no
1163    // authorization_servers -- this tells MCP clients (e.g. Claude Code SDK)
1164    // that the server exists but does NOT require OAuth authentication,
1165    // preventing them from gating the connection behind a broken auth flow.
1166    let server_url = if let Some(ref url) = config.public_url {
1167        url.trim_end_matches('/').to_owned()
1168    } else {
1169        let prm_scheme = if config.tls_cert_path.is_some() {
1170            "https"
1171        } else {
1172            "http"
1173        };
1174        format!("{prm_scheme}://{}", config.bind_addr)
1175    };
1176    let resource_url = format!("{server_url}/mcp");
1177
1178    #[cfg(feature = "oauth")]
1179    let prm_metadata = if let Some(ref auth_config) = config.auth
1180        && let Some(ref oauth_config) = auth_config.oauth
1181    {
1182        crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
1183    } else {
1184        serde_json::json!({ "resource": resource_url })
1185    };
1186    #[cfg(not(feature = "oauth"))]
1187    let prm_metadata = serde_json::json!({ "resource": resource_url });
1188
1189    router = router.route(
1190        "/.well-known/oauth-protected-resource",
1191        axum::routing::get(move || {
1192            let m = prm_metadata.clone();
1193            async move { axum::Json(m) }
1194        }),
1195    );
1196
1197    // OAuth 2.1 proxy endpoints: when an OAuth proxy is configured, expose
1198    // /authorize, /token, /register, and authorization server metadata so
1199    // MCP clients can perform Authorization Code + PKCE against the upstream
1200    // IdP (e.g. Keycloak) transparently.
1201    #[cfg(feature = "oauth")]
1202    if let Some(ref auth_config) = config.auth
1203        && let Some(ref oauth_config) = auth_config.oauth
1204        && oauth_config.proxy.is_some()
1205    {
1206        router =
1207            install_oauth_proxy_routes(router, &server_url, oauth_config, auth_state.as_ref())?;
1208    }
1209
1210    // OWASP security response headers (applied to all responses).
1211    // HSTS is conditional on TLS being configured.
1212    let is_tls = config.tls_cert_path.is_some();
1213    let security_headers_cfg = Arc::new(config.security_headers.clone());
1214    router = router.layer(axum::middleware::from_fn(move |req, next| {
1215        let cfg = Arc::clone(&security_headers_cfg);
1216        security_headers_middleware(is_tls, cfg, req, next)
1217    }));
1218
1219    // CORS preflight layer (required for browser-based MCP clients).
1220    // Uses the same effective origins as the origin check middleware
1221    // (including auto-derived origin from public_url).
1222    if !cors_origins.is_empty() {
1223        let cors = tower_http::cors::CorsLayer::new()
1224            .allow_origin(
1225                cors_origins
1226                    .iter()
1227                    .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
1228                    .collect::<Vec<_>>(),
1229            )
1230            .allow_methods([
1231                axum::http::Method::GET,
1232                axum::http::Method::POST,
1233                axum::http::Method::OPTIONS,
1234            ])
1235            .allow_headers([
1236                axum::http::header::CONTENT_TYPE,
1237                axum::http::header::AUTHORIZATION,
1238            ]);
1239        router = router.layer(cors);
1240    }
1241
1242    // Optional response compression (gzip + brotli). Skips small bodies
1243    // to avoid overhead. Applied after CORS so preflight responses remain
1244    // uncompressed.
1245    if config.compression_enabled {
1246        use tower_http::compression::Predicate as _;
1247        let predicate = tower_http::compression::DefaultPredicate::new().and(
1248            tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
1249        );
1250        router = router.layer(
1251            tower_http::compression::CompressionLayer::new()
1252                .gzip(true)
1253                .br(true)
1254                .compress_when(predicate),
1255        );
1256        tracing::info!(
1257            min_size = config.compression_min_size,
1258            "response compression enabled (gzip, br)"
1259        );
1260    }
1261
1262    // Optional global concurrency cap. `load_shed` converts the
1263    // `ConcurrencyLimit` back-pressure error into 503 instead of hanging.
1264    if let Some(max) = config.max_concurrent_requests {
1265        let overload_handler = tower::ServiceBuilder::new()
1266            .layer(axum::error_handling::HandleErrorLayer::new(
1267                |_err: tower::BoxError| async {
1268                    (
1269                        axum::http::StatusCode::SERVICE_UNAVAILABLE,
1270                        axum::Json(serde_json::json!({
1271                            "error": "overloaded",
1272                            "error_description": "server is at capacity, retry later"
1273                        })),
1274                    )
1275                },
1276            ))
1277            .layer(tower::load_shed::LoadShedLayer::new())
1278            .layer(tower::limit::ConcurrencyLimitLayer::new(max));
1279        router = router.layer(overload_handler);
1280        tracing::info!(max, "global concurrency limit enabled");
1281    }
1282
1283    // JSON fallback for unmatched routes. Without this, axum returns
1284    // an empty-body 404 that breaks MCP clients (e.g. Claude Code SDK)
1285    // when they probe OAuth endpoints like /authorize or /token.
1286    router = router.fallback(|| async {
1287        (
1288            axum::http::StatusCode::NOT_FOUND,
1289            axum::Json(serde_json::json!({
1290                "error": "not_found",
1291                "error_description": "The requested endpoint does not exist"
1292            })),
1293        )
1294    });
1295
1296    // Prometheus metrics: recording middleware + separate listener.
1297    #[cfg(feature = "metrics")]
1298    if config.metrics_enabled {
1299        let metrics = Arc::new(
1300            crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
1301        );
1302        let m = Arc::clone(&metrics);
1303        router = router.layer(axum::middleware::from_fn(
1304            move |req: Request<Body>, next: Next| {
1305                let m = Arc::clone(&m);
1306                metrics_middleware(m, req, next)
1307            },
1308        ));
1309        let metrics_bind = config.metrics_bind.clone();
1310        let metrics_shutdown = ct.clone();
1311        tokio::spawn(async move {
1312            if let Err(e) =
1313                crate::metrics::serve_metrics(metrics_bind, metrics, metrics_shutdown).await
1314            {
1315                tracing::error!("metrics listener failed: {e}");
1316            }
1317        });
1318    }
1319
1320    // Origin validation layer (MCP spec: servers MUST validate the
1321    // Origin header to prevent DNS rebinding attacks). Installed as the
1322    // OUTERMOST layer on the OUTER router so it protects ALL routes
1323    // (`/mcp`, `/healthz`, `/readyz`, `/version`, OAuth proxy endpoints,
1324    // admin endpoints, extra_router, etc.) and runs BEFORE auth so we
1325    // reject cross-origin attackers without spending Argon2 cycles.
1326    //
1327    // Origin-less requests (e.g. server-to-server probes, curl, native
1328    // MCP clients) are permitted; only requests with an Origin header
1329    // that does not match `effective_origins` are rejected.
1330    router = router.layer(axum::middleware::from_fn(move |req, next| {
1331        let origins = Arc::clone(&allowed_origins);
1332        origin_check_middleware(origins, log_request_headers, req, next)
1333    }));
1334
1335    let scheme = if config.tls_cert_path.is_some() {
1336        "https"
1337    } else {
1338        "http"
1339    };
1340
1341    let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
1342        (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
1343        _ => None,
1344    };
1345    let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();
1346
1347    Ok((
1348        router,
1349        AppRunParams {
1350            tls_paths,
1351            mtls_config,
1352            shutdown_timeout: config.shutdown_timeout,
1353            auth_state,
1354            rbac_swap,
1355            on_reload_ready: config.on_reload_ready.take(),
1356            ct,
1357            scheme,
1358            name: config.name.clone(),
1359        },
1360    ))
1361}
1362
1363/// Run the MCP HTTP server, binding to `config.bind_addr` and serving
1364/// until an OS shutdown signal (Ctrl-C / SIGTERM) is received.
1365///
1366/// This is the standard entry point for production deployments. For
1367/// deterministic shutdown control (e.g. integration tests), see
1368/// [`serve_with_listener`].
1369///
1370/// The configuration must be validated first via
1371/// [`McpServerConfig::validate`], which returns a [`Validated`] proof
1372/// token. This typestate guarantees, at compile time, that the server
1373/// never starts with an invalid configuration.
1374///
1375/// # Errors
1376///
1377/// Returns [`McpxError::Startup`] if binding to `config.bind_addr`
1378/// fails, or if the underlying axum server returns an error.
1379pub async fn serve<H, F>(
1380    config: Validated<McpServerConfig>,
1381    handler_factory: F,
1382) -> Result<(), McpxError>
1383where
1384    H: ServerHandler + 'static,
1385    F: Fn() -> H + Send + Sync + Clone + 'static,
1386{
1387    let config = config.into_inner();
1388    #[allow(
1389        deprecated,
1390        reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1391    )]
1392    let bind_addr = config.bind_addr.clone();
1393    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1394
1395    let listener = TcpListener::bind(&bind_addr)
1396        .await
1397        .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1398    log_listening(&params.name, params.scheme, &bind_addr);
1399
1400    run_server(
1401        router,
1402        listener,
1403        params.tls_paths,
1404        params.mtls_config,
1405        params.shutdown_timeout,
1406        params.auth_state,
1407        params.rbac_swap,
1408        params.on_reload_ready,
1409        params.ct,
1410    )
1411    .await
1412    .map_err(anyhow_to_startup)
1413}
1414
1415/// Run the MCP HTTP server on a pre-bound [`TcpListener`], with optional
1416/// readiness signalling and external shutdown control.
1417///
1418/// This variant is intended for **deterministic integration tests** and
1419/// for embedders that need to bind the listening socket themselves
1420/// (e.g. systemd socket activation). Compared to [`serve`]:
1421///
1422/// * The caller passes a `TcpListener` that is already bound. This
1423///   eliminates the bind race in tests that previously required
1424///   poll-the-`/healthz`-loop start-up detection.
1425/// * `ready_tx`, when `Some`, receives the socket's
1426///   [`SocketAddr`] *after* the router is built and immediately before
1427///   the server starts accepting connections. Tests can `await` the
1428///   matching `oneshot::Receiver` to know exactly when it is safe to
1429///   issue requests.
1430/// * `shutdown`, when `Some`, gives the caller a
1431///   [`CancellationToken`] that triggers the same graceful-shutdown
1432///   path as a real OS signal. This avoids cross-platform issues with
1433///   sending real `SIGTERM` from tests on Windows.
1434///
1435/// All three optional parameters degrade gracefully: if `ready_tx` is
1436/// `None`, no signal is sent; if `shutdown` is `None`, the server only
1437/// stops on an OS signal (just like [`serve`]).
1438///
1439/// # Errors
1440///
1441/// Returns [`McpxError::Startup`] if router construction fails, if reading
1442/// the listener's `local_addr()` fails, or if the underlying axum
1443/// server returns an error.
1444pub async fn serve_with_listener<H, F>(
1445    listener: TcpListener,
1446    config: Validated<McpServerConfig>,
1447    handler_factory: F,
1448    ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
1449    shutdown: Option<CancellationToken>,
1450) -> Result<(), McpxError>
1451where
1452    H: ServerHandler + 'static,
1453    F: Fn() -> H + Send + Sync + Clone + 'static,
1454{
1455    let config = config.into_inner();
1456    let local_addr = listener
1457        .local_addr()
1458        .map_err(|e| io_to_startup("listener.local_addr", e))?;
1459    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1460
1461    log_listening(&params.name, params.scheme, &local_addr.to_string());
1462
1463    // Forward external shutdown into the server-internal cancellation
1464    // token so `run_server`'s shutdown trigger picks it up alongside
1465    // any real OS signal.
1466    if let Some(external) = shutdown {
1467        let internal = params.ct.clone();
1468        tokio::spawn(async move {
1469            external.cancelled().await;
1470            internal.cancel();
1471        });
1472    }
1473
1474    // Signal readiness *after* the router is fully built and external
1475    // shutdown is wired, but *before* run_server takes ownership of
1476    // the listener. The receiver can immediately issue requests.
1477    if let Some(tx) = ready_tx {
1478        // Receiver may have been dropped (test gave up). That's fine.
1479        let _ = tx.send(local_addr);
1480    }
1481
1482    run_server(
1483        router,
1484        listener,
1485        params.tls_paths,
1486        params.mtls_config,
1487        params.shutdown_timeout,
1488        params.auth_state,
1489        params.rbac_swap,
1490        params.on_reload_ready,
1491        params.ct,
1492    )
1493    .await
1494    .map_err(anyhow_to_startup)
1495}
1496
/// Emit the standard "listening on …" log lines used by both
/// [`serve`] and [`serve_with_listener`].
///
/// Writes four `info`-level lines: a banner with the server name and
/// address, followed by the MCP endpoint (`/mcp`), health-check
/// (`/healthz`), and readiness (`/readyz`) URLs.
///
/// * `name` - server name shown in the banner line.
/// * `scheme` - URL scheme placed before `://` in each endpoint URL.
/// * `addr` - address text interpolated verbatim into every line.
#[allow(
    clippy::cognitive_complexity,
    reason = "tracing::info! macro expansions inflate the score; logic is trivial"
)]
fn log_listening(name: &str, scheme: &str, addr: &str) {
    tracing::info!("{name} listening on {addr}");
    tracing::info!("  MCP endpoint: {scheme}://{addr}/mcp");
    tracing::info!("  Health check: {scheme}://{addr}/healthz");
    tracing::info!("  Readiness:   {scheme}://{addr}/readyz");
}
1509
/// Drive the chosen axum server variant (TLS or plain) with a graceful
/// shutdown window. Consumes the router and listener.
///
/// The TLS variant is selected when `tls_paths` is `Some`; otherwise the
/// plain TCP listener is served directly. In both variants the optional
/// `on_reload_ready` callback is invoked once with a [`ReloadHandle`]
/// before the server starts accepting connections.
///
/// # Shutdown semantics
///
/// A single shutdown trigger (the FIRST of: OS signal via
/// `shutdown_signal()`, or external cancellation of `ct`) starts BOTH:
///
/// 1. axum's `.with_graceful_shutdown(...)` future, which stops
///    accepting new connections and waits for in-flight requests to
///    drain;
/// 2. a `tokio::time::sleep(shutdown_timeout)` race that forces exit if
///    drainage exceeds `shutdown_timeout`.
///
/// Previously this function awaited `shutdown_signal()` independently
/// in BOTH branches of a `tokio::select!`. Because `shutdown_signal`
/// resolves once per future and consumes one signal, the force-exit
/// timer was tied to a SECOND signal (a second SIGTERM the operator
/// would never send). Under a single SIGTERM the graceful drain could
/// hang indefinitely. The current implementation derives both branches
/// from a single shared trigger so the timeout race is anchored to the
/// FIRST (and only) signal.
///
/// When the force-exit timer wins the race, a warning is logged and the
/// function still returns `Ok(())`.
///
/// # Errors
///
/// Returns an error if the client-auth roots cannot be loaded, if the
/// initial CRL bootstrap fails, if the [`TlsListener`] cannot be
/// constructed, or if `axum::serve` itself reports an error.
#[allow(
    clippy::too_many_arguments,
    clippy::cognitive_complexity,
    reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
)]
async fn run_server(
    router: axum::Router,
    listener: TcpListener,
    tls_paths: Option<(PathBuf, PathBuf)>,
    mtls_config: Option<MtlsConfig>,
    shutdown_timeout: Duration,
    auth_state: Option<Arc<AuthState>>,
    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
    mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    ct: CancellationToken,
) -> anyhow::Result<()> {
    // `shutdown_trigger` fires when the FIRST source resolves: either
    // an OS signal (Ctrl-C / SIGTERM) or external cancellation of `ct`
    // (which the test harness uses for deterministic shutdown).
    let shutdown_trigger = CancellationToken::new();
    {
        let trigger = shutdown_trigger.clone();
        let parent = ct.clone();
        // Detached watcher task: it is never joined; its only effect is
        // cancelling `shutdown_trigger`.
        tokio::spawn(async move {
            tokio::select! {
                () = shutdown_signal() => {}
                () = parent.cancelled() => {}
            }
            trigger.cancel();
        });
    }

    // Future handed to axum's `with_graceful_shutdown`. Besides starting
    // the drain it cancels `ct`, so every holder of a `ct` clone (e.g.
    // the CRL refresher spawned below) observes the shutdown as well.
    let graceful = {
        let trigger = shutdown_trigger.clone();
        let ct = ct.clone();
        async move {
            trigger.cancelled().await;
            tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
            ct.cancel();
        }
    };

    // The grace-period clock only starts once the trigger has fired;
    // before that this future stays pending.
    let force_exit_timer = {
        let trigger = shutdown_trigger.clone();
        async move {
            trigger.cancelled().await;
            tokio::time::sleep(shutdown_timeout).await;
        }
    };

    if let Some((cert_path, key_path)) = tls_paths {
        // CRL state is only bootstrapped when mTLS is configured AND CRL
        // checking is enabled; otherwise no CRL set exists.
        let crl_set = if let Some(mtls) = mtls_config.as_ref()
            && mtls.crl_enabled
        {
            let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
            let (crl_set, discover_rx) =
                mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
                    .await
                    .map_err(|error| anyhow::anyhow!(error.to_string()))?;
            // Background refresher; it receives a clone of `ct`
            // (presumably to stop on shutdown — see `run_crl_refresher`).
            tokio::spawn(mtls_revocation::run_crl_refresher(
                Arc::clone(&crl_set),
                discover_rx,
                ct.clone(),
            ));
            Some(crl_set)
        } else {
            None
        };

        // Hand the hot-reload handle to the caller before serving starts.
        if let Some(cb) = on_reload_ready.take() {
            cb(ReloadHandle {
                auth: auth_state.clone(),
                rbac: Some(Arc::clone(&rbac_swap)),
                crl_set: crl_set.clone(),
            });
        }

        let tls_listener = TlsListener::new(
            listener,
            &cert_path,
            &key_path,
            mtls_config.as_ref(),
            crl_set,
        )?;
        let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
        // Race the server (with graceful drain) against the force-exit
        // timer; whichever finishes first ends this function.
        tokio::select! {
            result = axum::serve(tls_listener, make_svc)
                .with_graceful_shutdown(graceful) => { result?; }
            () = force_exit_timer => {
                tracing::warn!("shutdown timeout exceeded, forcing exit");
            }
        }
    } else {
        // Plain-TCP variant: there is no CRL state to expose.
        if let Some(cb) = on_reload_ready.take() {
            cb(ReloadHandle {
                auth: auth_state,
                rbac: Some(rbac_swap),
                crl_set: None,
            });
        }

        let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
        // Same race as the TLS variant above.
        tokio::select! {
            result = axum::serve(listener, make_svc)
                .with_graceful_shutdown(graceful) => { result?; }
            () = force_exit_timer => {
                tracing::warn!("shutdown timeout exceeded, forcing exit");
            }
        }
    }

    Ok(())
}
1645
/// Mount the OAuth 2.1 proxy endpoints on `router`: the
/// `/.well-known/oauth-authorization-server` metadata document,
/// `/authorize`, `/token`, `/register`, and, when enabled, the admin
/// sub-router produced by [`build_oauth_admin_router`].
///
/// When `oauth_config.proxy` is `None` the router is returned unchanged.
///
/// # Errors
///
/// Returns an error if the shared [`crate::oauth::OauthHttpClient`]
/// cannot be initialized, or if building the admin sub-router fails.
#[cfg(feature = "oauth")]
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    use axum::{
        middleware::from_fn,
        routing::{get, post},
    };

    let Some(proxy) = oauth_config.proxy.as_ref() else {
        return Ok(router);
    };

    // One HTTP client is shared by every proxy endpoint; each handler
    // works on its own clone of it.
    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;

    let metadata = crate::oauth::authorization_server_metadata(server_url, oauth_config);
    let authorize_cfg = proxy.clone();
    let token_cfg = proxy.clone();
    let token_client = http.clone();
    let register_cfg = proxy.clone();

    let router = router
        .route(
            "/.well-known/oauth-authorization-server",
            get(move || {
                let document = metadata.clone();
                async move { axum::Json(document) }
            }),
        )
        .route(
            "/authorize",
            get(
                move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
                    let cfg = authorize_cfg.clone();
                    async move { crate::oauth::handle_authorize(&cfg, &query.unwrap_or_default()) }
                },
            ),
        )
        .route(
            "/token",
            post(move |body: String| {
                let cfg = token_cfg.clone();
                let client = token_client.clone();
                async move { crate::oauth::handle_token(&client, &cfg, &body).await }
            })
            .layer(from_fn(oauth_token_cache_headers_middleware)),
        )
        .route(
            "/register",
            post(move |axum::Json(body): axum::Json<serde_json::Value>| {
                let cfg = register_cfg;
                async move { axum::Json(crate::oauth::handle_register(&cfg, &body)) }
            })
            .layer(from_fn(oauth_token_cache_headers_middleware)),
        );

    // Each admin endpoint is exposed only when admin endpoints are
    // switched on AND the matching upstream URL is configured.
    let introspect_exposed = proxy.expose_admin_endpoints && proxy.introspection_url.is_some();
    let revoke_exposed = proxy.expose_admin_endpoints && proxy.revocation_url.is_some();

    let unauthenticated_opt_in = proxy.expose_admin_endpoints
        && !proxy.require_auth_on_admin_endpoints
        && proxy.allow_unauthenticated_admin_endpoints;
    if unauthenticated_opt_in {
        // The operator explicitly opted out of authentication on the
        // admin endpoints; log it loudly so the choice is auditable.
        tracing::warn!(
            "OAuth introspect/revoke endpoints are unauthenticated by explicit \
             allow_unauthenticated_admin_endpoints opt-out; ensure an \
             authenticated reverse proxy fronts these routes"
        );
    }

    let admin_router = if introspect_exposed || revoke_exposed {
        build_oauth_admin_router(proxy, http, auth_state)?
    } else {
        axum::Router::new()
    };
    let router = router.merge(admin_router);

    tracing::info!(
        introspect = introspect_exposed,
        revoke = revoke_exposed,
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
1746
/// Assemble the optional admin sub-router exposing `/introspect` and
/// `/revoke`.
///
/// A route is added only when its upstream URL
/// (`proxy.introspection_url` / `proxy.revocation_url`) is configured.
/// The sub-router is wrapped in [`oauth_token_cache_headers_middleware`]
/// and, when `proxy.require_auth_on_admin_endpoints` is set, also in the
/// auth middleware.
///
/// # Errors
///
/// Returns [`McpxError::Startup`] when authentication is required on the
/// admin endpoints but `auth_state` is `None`.
#[cfg(feature = "oauth")]
fn build_oauth_admin_router(
    proxy: &crate::oauth::OAuthProxyConfig,
    http: crate::oauth::OauthHttpClient,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    use axum::{middleware::from_fn, routing::post};

    let mut routes = axum::Router::new();

    if proxy.introspection_url.is_some() {
        let cfg = proxy.clone();
        let client = http.clone();
        routes = routes.route(
            "/introspect",
            post(move |body: String| {
                let (cfg, client) = (cfg.clone(), client.clone());
                async move { crate::oauth::handle_introspect(&client, &cfg, &body).await }
            }),
        );
    }

    if proxy.revocation_url.is_some() {
        let cfg = proxy.clone();
        // Last use of `http`: the handler takes ownership of it.
        routes = routes.route(
            "/revoke",
            post(move |body: String| {
                let (cfg, client) = (cfg.clone(), http.clone());
                async move { crate::oauth::handle_revoke(&client, &cfg, &body).await }
            }),
        );
    }

    let routes = routes.layer(from_fn(oauth_token_cache_headers_middleware));

    if !proxy.require_auth_on_admin_endpoints {
        return Ok(routes);
    }

    let state = auth_state.map(Arc::clone).ok_or_else(|| {
        McpxError::Startup("oauth proxy admin endpoints require auth state".into())
    })?;
    Ok(routes.layer(from_fn(move |req, next| {
        auth_middleware(Arc::clone(&state), req, next)
    })))
}
1805
1806/// Build the host allow-list for rmcp's DNS rebinding protection.
1807///
1808/// Includes loopback hosts by default, then augments with host/authority
1809/// derived from `public_url` and the server bind address.
1810fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
1811    let mut hosts = vec![
1812        "localhost".to_owned(),
1813        "127.0.0.1".to_owned(),
1814        "::1".to_owned(),
1815    ];
1816
1817    if let Some(url) = public_url
1818        && let Ok(uri) = url.parse::<axum::http::Uri>()
1819        && let Some(authority) = uri.authority()
1820    {
1821        let host = authority.host().to_owned();
1822        if !hosts.iter().any(|h| h == &host) {
1823            hosts.push(host);
1824        }
1825
1826        let authority = authority.as_str().to_owned();
1827        if !hosts.iter().any(|h| h == &authority) {
1828            hosts.push(authority);
1829        }
1830    }
1831
1832    if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
1833        && let Some(authority) = uri.authority()
1834    {
1835        let host = authority.host().to_owned();
1836        if !hosts.iter().any(|h| h == &host) {
1837            hosts.push(host);
1838        }
1839
1840        let authority = authority.as_str().to_owned();
1841        if !hosts.iter().any(|h| h == &authority) {
1842            hosts.push(authority);
1843        }
1844    }
1845
1846    hosts
1847}
1848
1849// - TLS support -
1850
1851/// Implement axum's `Connected` trait for `TlsConnInfo` so that
1852/// `ConnectInfo<TlsConnInfo>` is available in middleware when serving
1853/// over our custom `TlsListener`.
1854///
1855/// The identity is read directly from the wrapping
1856/// [`AuthenticatedTlsStream`], which guarantees one-to-one correspondence
1857/// between the TLS connection and its mTLS identity. This eliminates the
1858/// previous shared-map approach which was vulnerable to ephemeral-port
1859/// reuse races (an unauthenticated reconnection from the same `(IP, port)`
1860/// pair could alias a stale entry).
1861impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
1862    for TlsConnInfo
1863{
1864    fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
1865        let addr = *target.remote_addr();
1866        let identity = target.io().identity().cloned();
1867        TlsConnInfo::new(addr, identity)
1868    }
1869}
1870
/// A TLS-wrapping listener that implements axum's `Listener` trait.
///
/// When mTLS is configured, verifies client certificates against the
/// configured CA and extracts the client identity at handshake time.
/// The extracted identity is bound to the connection itself via the
/// returned [`AuthenticatedTlsStream`], so it is impossible for an
/// unrelated connection to observe it.
struct TlsListener {
    /// Plain TCP listener from which raw connections are accepted.
    inner: TcpListener,
    /// Performs the server-side TLS handshake on each accepted stream.
    acceptor: tokio_rustls::TlsAcceptor,
    /// Default role passed to `extract_mtls_identity` when deriving an
    /// identity from a client certificate. Taken from the mTLS config,
    /// or `"viewer"` when mTLS is not configured.
    mtls_default_role: String,
}
1883
1884impl TlsListener {
1885    fn new(
1886        inner: TcpListener,
1887        cert_path: &Path,
1888        key_path: &Path,
1889        mtls_config: Option<&MtlsConfig>,
1890        crl_set: Option<Arc<CrlSet>>,
1891    ) -> anyhow::Result<Self> {
1892        // Install the ring crypto provider (ok to call multiple times).
1893        rustls::crypto::ring::default_provider()
1894            .install_default()
1895            .ok();
1896
1897        let certs = load_certs(cert_path)?;
1898        let key = load_key(key_path)?;
1899
1900        let mtls_default_role;
1901
1902        let tls_config = if let Some(mtls) = mtls_config {
1903            mtls_default_role = mtls.default_role.clone();
1904            let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
1905            {
1906                let Some(crl_set) = crl_set else {
1907                    return Err(anyhow::anyhow!(
1908                        "mTLS CRL verifier requested but CRL state was not initialized"
1909                    ));
1910                };
1911                Arc::new(DynamicClientCertVerifier::new(crl_set))
1912            } else {
1913                let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
1914                if mtls.required {
1915                    rustls::server::WebPkiClientVerifier::builder(root_store)
1916                        .build()
1917                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1918                } else {
1919                    rustls::server::WebPkiClientVerifier::builder(root_store)
1920                        .allow_unauthenticated()
1921                        .build()
1922                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1923                }
1924            };
1925
1926            tracing::info!(
1927                ca = %mtls.ca_cert_path.display(),
1928                required = mtls.required,
1929                crl_enabled = mtls.crl_enabled,
1930                "mTLS client auth configured"
1931            );
1932
1933            rustls::ServerConfig::builder_with_protocol_versions(&[
1934                &rustls::version::TLS12,
1935                &rustls::version::TLS13,
1936            ])
1937            .with_client_cert_verifier(verifier)
1938            .with_single_cert(certs, key)?
1939        } else {
1940            mtls_default_role = "viewer".to_owned();
1941            rustls::ServerConfig::builder_with_protocol_versions(&[
1942                &rustls::version::TLS12,
1943                &rustls::version::TLS13,
1944            ])
1945            .with_no_client_auth()
1946            .with_single_cert(certs, key)?
1947        };
1948
1949        let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
1950        tracing::info!(
1951            "TLS enabled (cert: {}, key: {})",
1952            cert_path.display(),
1953            key_path.display()
1954        );
1955        Ok(Self {
1956            inner,
1957            acceptor,
1958            mtls_default_role,
1959        })
1960    }
1961
1962    /// Extract the mTLS client cert identity from a completed TLS handshake.
1963    /// Returns `None` if no client certificate was presented or if the
1964    /// certificate could not be parsed into an [`AuthIdentity`].
1965    fn extract_handshake_identity(
1966        tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
1967        default_role: &str,
1968        addr: SocketAddr,
1969    ) -> Option<AuthIdentity> {
1970        let (_, server_conn) = tls_stream.get_ref();
1971        let cert_der = server_conn.peer_certificates()?.first()?;
1972        let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
1973        tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
1974        Some(id)
1975    }
1976}
1977
/// A TLS stream paired with the mTLS identity extracted at handshake time.
///
/// Wraps [`tokio_rustls::server::TlsStream`] so the verified client
/// identity travels with the connection itself. This replaces the previous
/// shared `MtlsIdentities` map, eliminating the
/// `(SocketAddr) -> AuthIdentity` aliasing risk caused by ephemeral-port
/// reuse and removing the need for an LRU eviction policy.
///
/// The wrapper is `Unpin` (its inner stream is `Unpin` because
/// [`tokio::net::TcpStream`] is `Unpin`), so `AsyncRead`/`AsyncWrite`
/// delegation uses safe pin projection via `Pin::new(&mut self.inner)`.
pub(crate) struct AuthenticatedTlsStream {
    /// The established TLS stream; all reads and writes are delegated
    /// to it.
    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
    /// Identity extracted from the client certificate at handshake time;
    /// `None` when no identity could be extracted.
    identity: Option<AuthIdentity>,
}
1993
1994impl AuthenticatedTlsStream {
1995    /// Returns the verified mTLS client identity, if any.
1996    #[must_use]
1997    pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
1998        self.identity.as_ref()
1999    }
2000}
2001
2002impl std::fmt::Debug for AuthenticatedTlsStream {
2003    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2004        f.debug_struct("AuthenticatedTlsStream")
2005            .field("identity", &self.identity.as_ref().map(|id| &id.name))
2006            .finish_non_exhaustive()
2007    }
2008}
2009
2010impl tokio::io::AsyncRead for AuthenticatedTlsStream {
2011    fn poll_read(
2012        mut self: Pin<&mut Self>,
2013        cx: &mut std::task::Context<'_>,
2014        buf: &mut tokio::io::ReadBuf<'_>,
2015    ) -> std::task::Poll<std::io::Result<()>> {
2016        Pin::new(&mut self.inner).poll_read(cx, buf)
2017    }
2018}
2019
2020impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
2021    fn poll_write(
2022        mut self: Pin<&mut Self>,
2023        cx: &mut std::task::Context<'_>,
2024        buf: &[u8],
2025    ) -> std::task::Poll<std::io::Result<usize>> {
2026        Pin::new(&mut self.inner).poll_write(cx, buf)
2027    }
2028
2029    fn poll_flush(
2030        mut self: Pin<&mut Self>,
2031        cx: &mut std::task::Context<'_>,
2032    ) -> std::task::Poll<std::io::Result<()>> {
2033        Pin::new(&mut self.inner).poll_flush(cx)
2034    }
2035
2036    fn poll_shutdown(
2037        mut self: Pin<&mut Self>,
2038        cx: &mut std::task::Context<'_>,
2039    ) -> std::task::Poll<std::io::Result<()>> {
2040        Pin::new(&mut self.inner).poll_shutdown(cx)
2041    }
2042
2043    fn poll_write_vectored(
2044        mut self: Pin<&mut Self>,
2045        cx: &mut std::task::Context<'_>,
2046        bufs: &[std::io::IoSlice<'_>],
2047    ) -> std::task::Poll<std::io::Result<usize>> {
2048        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
2049    }
2050
2051    fn is_write_vectored(&self) -> bool {
2052        self.inner.is_write_vectored()
2053    }
2054}
2055
2056impl axum::serve::Listener for TlsListener {
2057    type Io = AuthenticatedTlsStream;
2058    type Addr = SocketAddr;
2059
2060    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
2061        loop {
2062            let (stream, addr) = match self.inner.accept().await {
2063                Ok(pair) => pair,
2064                Err(e) => {
2065                    tracing::debug!("TCP accept error: {e}");
2066                    continue;
2067                }
2068            };
2069            let tls_stream = match self.acceptor.accept(stream).await {
2070                Ok(s) => s,
2071                Err(e) => {
2072                    tracing::debug!("TLS handshake failed from {addr}: {e}");
2073                    continue;
2074                }
2075            };
2076            let identity =
2077                Self::extract_handshake_identity(&tls_stream, &self.mtls_default_role, addr);
2078            let wrapped = AuthenticatedTlsStream {
2079                inner: tls_stream,
2080                identity,
2081            };
2082            return (wrapped, addr);
2083        }
2084    }
2085
2086    fn local_addr(&self) -> std::io::Result<Self::Addr> {
2087        self.inner.local_addr()
2088    }
2089}
2090
2091fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
2092    use rustls::pki_types::pem::PemObject;
2093    let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
2094        .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
2095        .collect::<Result<_, _>>()
2096        .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
2097    anyhow::ensure!(
2098        !certs.is_empty(),
2099        "no certificates found in {}",
2100        path.display()
2101    );
2102    Ok(certs)
2103}
2104
2105fn load_client_auth_roots(
2106    path: &Path,
2107) -> anyhow::Result<(
2108    Vec<rustls::pki_types::CertificateDer<'static>>,
2109    Arc<RootCertStore>,
2110)> {
2111    let ca_certs = load_certs(path)?;
2112    let mut root_store = RootCertStore::empty();
2113    for cert in &ca_certs {
2114        root_store
2115            .add(cert.clone())
2116            .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
2117    }
2118
2119    Ok((ca_certs, Arc::new(root_store)))
2120}
2121
2122fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
2123    use rustls::pki_types::pem::PemObject;
2124    rustls::pki_types::PrivateKeyDer::from_pem_file(path)
2125        .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
2126}
2127
2128#[allow(clippy::unused_async)]
2129async fn healthz() -> impl IntoResponse {
2130    axum::Json(serde_json::json!({
2131        "status": "ok",
2132    }))
2133}
2134
2135/// Build the `/version` JSON payload for a given server name and version.
2136///
2137/// Build metadata (`build_git_sha`, `build_timestamp`, `rust_version`) is
2138/// read at compile time from the `MCPX_BUILD_SHA`, `MCPX_BUILD_TIME`, and
2139/// `MCPX_RUSTC_VERSION` env vars. Unset values resolve to `"unknown"`.
2140fn version_payload(name: &str, version: &str) -> serde_json::Value {
2141    serde_json::json!({
2142        "name": name,
2143        "version": version,
2144        "build_git_sha": option_env!("MCPX_BUILD_SHA").unwrap_or("unknown"),
2145        "build_timestamp": option_env!("MCPX_BUILD_TIME").unwrap_or("unknown"),
2146        "rust_version": option_env!("MCPX_RUSTC_VERSION").unwrap_or("unknown"),
2147        "mcpx_version": env!("CARGO_PKG_VERSION"),
2148    })
2149}
2150
2151/// Pre-serialize the `/version` payload to immutable bytes.
2152///
2153/// This is called once at router-build time so per-request handling can
2154/// reuse a cheap `Arc<[u8]>` clone instead of re-serializing a
2155/// [`serde_json::Value`] on every hit.
2156///
2157/// Serialization of a flat `serde_json::Value` of static-string fields
2158/// cannot fail in practice; the fallback to `b"{}"` exists only to
2159/// satisfy the crate-wide `unwrap_used` / `expect_used` lint policy.
2160fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
2161    let value = version_payload(name, version);
2162    serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
2163}
2164
2165async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2166    let status = check().await;
2167    let ready = status
2168        .get("ready")
2169        .and_then(serde_json::Value::as_bool)
2170        .unwrap_or(false);
2171    let code = if ready {
2172        axum::http::StatusCode::OK
2173    } else {
2174        axum::http::StatusCode::SERVICE_UNAVAILABLE
2175    };
2176    (code, axum::Json(status))
2177}
2178
2179/// Wait for SIGINT (ctrl-c) or SIGTERM (container stop).
2180///
2181/// On non-Unix platforms, only SIGINT is handled.
2182async fn shutdown_signal() {
2183    let ctrl_c = tokio::signal::ctrl_c();
2184
2185    #[cfg(unix)]
2186    {
2187        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
2188            Ok(mut term) => {
2189                tokio::select! {
2190                    _ = ctrl_c => {}
2191                    _ = term.recv() => {}
2192                }
2193            }
2194            Err(e) => {
2195                tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
2196                ctrl_c.await.ok();
2197            }
2198        }
2199    }
2200
2201    #[cfg(not(unix))]
2202    {
2203        ctrl_c.await.ok();
2204    }
2205}
2206
2207// -- Origin validation (MCP 2025-11-25 spec, section 2.0.1) --
2208
/// Record HTTP request metrics (method, path, status, duration).
///
/// Runs the inner service first, then increments `http_requests_total`
/// labelled by method, path, and status code, and observes the elapsed
/// seconds in `http_request_duration_seconds` labelled by method and
/// path. The response is returned unchanged.
///
/// NOTE(review): the `path` label is the raw request path, not a matched
/// route template, so label cardinality follows whatever URLs clients
/// request — confirm this is acceptable.
#[cfg(feature = "metrics")]
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    // Capture the labels before `req` is moved into the inner service.
    let method = req.method().to_string();
    let path = req.uri().path().to_owned();
    let start = std::time::Instant::now();

    let response = next.run(req).await;

    let status = response.status().as_u16().to_string();
    let duration = start.elapsed().as_secs_f64();

    metrics
        .http_requests_total
        .with_label_values(&[&method, &path, &status])
        .inc();
    metrics
        .http_request_duration_seconds
        .with_label_values(&[&method, &path])
        .observe(duration);

    response
}
2238
2239/// OWASP security header hardening applied to every response.
2240///
2241/// Sets: `X-Content-Type-Options`, `X-Frame-Options`, `Cache-Control`,
2242/// `Referrer-Policy`, `Cross-Origin-Opener-Policy`, `Cross-Origin-Resource-Policy`,
2243/// `Cross-Origin-Embedder-Policy`, `Permissions-Policy`,
2244/// `X-Permitted-Cross-Domain-Policies`, `Content-Security-Policy`,
2245/// `X-DNS-Prefetch-Control`, and (when TLS is active) `Strict-Transport-Security`.
2246///
2247/// Each header's value can be customised via [`SecurityHeadersConfig`]
2248/// on [`McpServerConfig`]. See that type for the three-state semantic
2249/// (`None` = default, `Some("")` = omit, `Some(v)` = override).
2250async fn security_headers_middleware(
2251    is_tls: bool,
2252    cfg: Arc<SecurityHeadersConfig>,
2253    req: Request<Body>,
2254    next: Next,
2255) -> axum::response::Response {
2256    use axum::http::{HeaderName, header};
2257
2258    let mut resp = next.run(req).await;
2259    let headers = resp.headers_mut();
2260
2261    // Strip server identity headers to reduce information leakage.
2262    headers.remove(header::SERVER);
2263    headers.remove(HeaderName::from_static("x-powered-by"));
2264
2265    apply_security_header(
2266        headers,
2267        header::X_CONTENT_TYPE_OPTIONS,
2268        cfg.x_content_type_options.as_deref(),
2269        "nosniff",
2270    );
2271    apply_security_header(
2272        headers,
2273        header::X_FRAME_OPTIONS,
2274        cfg.x_frame_options.as_deref(),
2275        "deny",
2276    );
2277    apply_security_header(
2278        headers,
2279        header::CACHE_CONTROL,
2280        cfg.cache_control.as_deref(),
2281        "no-store, max-age=0",
2282    );
2283    apply_security_header(
2284        headers,
2285        header::REFERRER_POLICY,
2286        cfg.referrer_policy.as_deref(),
2287        "no-referrer",
2288    );
2289    apply_security_header(
2290        headers,
2291        HeaderName::from_static("cross-origin-opener-policy"),
2292        cfg.cross_origin_opener_policy.as_deref(),
2293        "same-origin",
2294    );
2295    apply_security_header(
2296        headers,
2297        HeaderName::from_static("cross-origin-resource-policy"),
2298        cfg.cross_origin_resource_policy.as_deref(),
2299        "same-origin",
2300    );
2301    apply_security_header(
2302        headers,
2303        HeaderName::from_static("cross-origin-embedder-policy"),
2304        cfg.cross_origin_embedder_policy.as_deref(),
2305        "require-corp",
2306    );
2307    apply_security_header(
2308        headers,
2309        HeaderName::from_static("permissions-policy"),
2310        cfg.permissions_policy.as_deref(),
2311        "accelerometer=(), camera=(), geolocation=(), microphone=()",
2312    );
2313    apply_security_header(
2314        headers,
2315        HeaderName::from_static("x-permitted-cross-domain-policies"),
2316        cfg.x_permitted_cross_domain_policies.as_deref(),
2317        "none",
2318    );
2319    apply_security_header(
2320        headers,
2321        HeaderName::from_static("content-security-policy"),
2322        cfg.content_security_policy.as_deref(),
2323        "default-src 'none'; frame-ancestors 'none'",
2324    );
2325    apply_security_header(
2326        headers,
2327        HeaderName::from_static("x-dns-prefetch-control"),
2328        cfg.x_dns_prefetch_control.as_deref(),
2329        "off",
2330    );
2331
2332    if is_tls {
2333        apply_security_header(
2334            headers,
2335            header::STRICT_TRANSPORT_SECURITY,
2336            cfg.strict_transport_security.as_deref(),
2337            "max-age=63072000; includeSubDomains",
2338        );
2339    }
2340
2341    resp
2342}
2343
2344/// Set a single security header on the response, honouring the
2345/// three-state override semantic (None = default, Some("") = omit,
2346/// Some(value) = override).
2347///
2348/// Defence-in-depth: if an override value somehow reaches this point
2349/// despite [`validate_security_headers`] having approved it (e.g. a
2350/// runtime mutation on a non-`Validated` field), we log at error level
2351/// and fall back to the static default rather than panicking. The
2352/// `Validated<McpServerConfig>` type makes that path unreachable in
2353/// well-typed code paths.
2354fn apply_security_header(
2355    headers: &mut axum::http::HeaderMap,
2356    name: axum::http::HeaderName,
2357    override_value: Option<&str>,
2358    default: &'static str,
2359) {
2360    use axum::http::HeaderValue;
2361
2362    match override_value {
2363        None => {
2364            headers.insert(name, HeaderValue::from_static(default));
2365        }
2366        Some("") => {
2367            // Operator explicitly opted out of this header.
2368        }
2369        Some(v) => match HeaderValue::from_str(v) {
2370            Ok(hv) => {
2371                headers.insert(name, hv);
2372            }
2373            Err(err) => {
2374                tracing::error!(
2375                    header = %name,
2376                    error = %err,
2377                    "invalid security header override reached middleware; using default"
2378                );
2379                headers.insert(name, HeaderValue::from_static(default));
2380            }
2381        },
2382    }
2383}
2384
2385/// Validate every non-empty entry in a [`SecurityHeadersConfig`].
2386///
2387/// - `None` and `Some("")` are accepted unconditionally (use-default and
2388///   omit, respectively).
2389/// - `Some(v)` is rejected if `axum::http::HeaderValue::from_str(v)` fails.
2390/// - `strict_transport_security` additionally rejects any value
2391///   containing `preload` (case-insensitive). Operators who genuinely
2392///   want to commit to the HSTS preload list must do so via a future
2393///   explicit `with_hsts_preload(true)` builder, not by smuggling
2394///   `preload` through this knob.
2395fn validate_security_headers(cfg: &SecurityHeadersConfig) -> Result<(), McpxError> {
2396    use axum::http::HeaderValue;
2397
2398    let fields: &[(&str, Option<&str>)] = &[
2399        (
2400            "x_content_type_options",
2401            cfg.x_content_type_options.as_deref(),
2402        ),
2403        ("x_frame_options", cfg.x_frame_options.as_deref()),
2404        ("cache_control", cfg.cache_control.as_deref()),
2405        ("referrer_policy", cfg.referrer_policy.as_deref()),
2406        (
2407            "cross_origin_opener_policy",
2408            cfg.cross_origin_opener_policy.as_deref(),
2409        ),
2410        (
2411            "cross_origin_resource_policy",
2412            cfg.cross_origin_resource_policy.as_deref(),
2413        ),
2414        (
2415            "cross_origin_embedder_policy",
2416            cfg.cross_origin_embedder_policy.as_deref(),
2417        ),
2418        ("permissions_policy", cfg.permissions_policy.as_deref()),
2419        (
2420            "x_permitted_cross_domain_policies",
2421            cfg.x_permitted_cross_domain_policies.as_deref(),
2422        ),
2423        (
2424            "content_security_policy",
2425            cfg.content_security_policy.as_deref(),
2426        ),
2427        (
2428            "x_dns_prefetch_control",
2429            cfg.x_dns_prefetch_control.as_deref(),
2430        ),
2431        (
2432            "strict_transport_security",
2433            cfg.strict_transport_security.as_deref(),
2434        ),
2435    ];
2436
2437    for (field, value) in fields {
2438        let Some(v) = value else { continue };
2439        if v.is_empty() {
2440            continue;
2441        }
2442        if let Err(err) = HeaderValue::from_str(v) {
2443            return Err(McpxError::Config(format!(
2444                "invalid security_headers.{field}: {err}"
2445            )));
2446        }
2447    }
2448
2449    if let Some(v) = cfg.strict_transport_security.as_deref()
2450        && !v.is_empty()
2451        && v.to_ascii_lowercase().contains("preload")
2452    {
2453        return Err(McpxError::Config(format!(
2454            "invalid security_headers.strict_transport_security: {v:?} contains the `preload` directive; \
2455             HSTS preload must be opted into explicitly via a dedicated builder, not via this knob"
2456        )));
2457    }
2458
2459    Ok(())
2460}
2461
/// Add the cache-related headers that RFC 6749 §5.1 / RFC 6750 §5.4
/// require on OAuth token-issuing responses.
///
/// [`security_headers_middleware`] already applies
/// `Cache-Control: no-store, max-age=0` globally; on top of that this
/// middleware sets:
///
/// - `Pragma: no-cache` -- mandated by RFC 6749 §5.1 for HTTP/1.0 caches.
/// - `Vary: Authorization` -- mandated by RFC 6750 §5.4 for endpoints
///   whose response depends on the `Authorization` header.
///
/// Intended only for the OAuth proxy token-class endpoints (`/token`,
/// `/register`, `/introspect`, `/revoke`). `Pragma` is inserted
/// (replacing any previous value) while `Vary` is appended, so a `Vary`
/// value that is already present (e.g. `Accept-Encoding` from a
/// compression layer, or `Origin` from a CORS layer) is preserved.
#[cfg(feature = "oauth")]
async fn oauth_token_cache_headers_middleware(
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    use axum::http::{
        HeaderValue,
        header::{PRAGMA, VARY},
    };

    let mut response = next.run(req).await;

    let out = response.headers_mut();
    out.insert(PRAGMA, HeaderValue::from_static("no-cache"));
    out.append(VARY, HeaderValue::from_static("Authorization"));

    response
}
2489
2490/// Per the MCP spec: if the Origin header is present and its value is not in
2491/// the allowed list, respond with 403 Forbidden. Requests without an Origin
2492/// header are allowed through (e.g. non-browser clients like curl, SDKs).
2493async fn origin_check_middleware(
2494    allowed: Arc<[String]>,
2495    log_request_headers: bool,
2496    req: Request<Body>,
2497    next: Next,
2498) -> axum::response::Response {
2499    let method = req.method().clone();
2500    let path = req.uri().path().to_owned();
2501
2502    log_incoming_request(&method, &path, req.headers(), log_request_headers);
2503
2504    if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
2505        let origin_str = origin.to_str().unwrap_or("");
2506        if !allowed.iter().any(|a| a == origin_str) {
2507            tracing::warn!(
2508                origin = origin_str,
2509                %method,
2510                %path,
2511                allowed = ?&*allowed,
2512                "rejected request: Origin not allowed"
2513            );
2514            return (
2515                axum::http::StatusCode::FORBIDDEN,
2516                "Forbidden: Origin not allowed",
2517            )
2518                .into_response();
2519        }
2520    }
2521    next.run(req).await
2522}
2523
2524/// Emit a DEBUG log for an incoming request, optionally including the full
2525/// (redacted) header set.
2526fn log_incoming_request(
2527    method: &axum::http::Method,
2528    path: &str,
2529    headers: &axum::http::HeaderMap,
2530    log_request_headers: bool,
2531) {
2532    if log_request_headers {
2533        tracing::debug!(
2534            %method,
2535            %path,
2536            headers = %format_request_headers_for_log(headers),
2537            "incoming request"
2538        );
2539    } else {
2540        tracing::debug!(%method, %path, "incoming request");
2541    }
2542}
2543
2544fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
2545    headers
2546        .iter()
2547        .map(|(k, v)| {
2548            let name = k.as_str();
2549            if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
2550                format!("{name}: [REDACTED]")
2551            } else {
2552                format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
2553            }
2554        })
2555        .collect::<Vec<_>>()
2556        .join(", ")
2557}
2558
2559// -- stdio transport --
2560
2561/// Serve an MCP server over stdin/stdout (stdio transport).
2562///
2563/// # Security warnings
2564///
2565/// - **No authentication**: the parent process has full, unrestricted access.
2566/// - **No RBAC**: all tools are available regardless of policy.
2567/// - **No TLS**: messages travel over OS pipes in plaintext.
2568/// - **Single client**: only the parent process can connect.
2569/// - **No Origin validation**: not applicable to stdio.
2570///
2571/// Use this only when the MCP client spawns the server as a trusted subprocess
2572/// (e.g. Claude Desktop, VS Code Copilot). For network-accessible deployments,
2573/// use `serve()` (Streamable HTTP) instead.
2574///
2575/// # Errors
2576///
2577/// Returns [`McpxError::Startup`] if the handler fails to initialize or the
2578/// transport disconnects unexpectedly.
2579// NOTE: reported complexity 32/25 is driven entirely by `tracing::*!`
2580// macro expansion in this 18-line function (info/warn/info + two matches).
2581// There is nothing meaningful to extract; the allow stays.
2582#[allow(clippy::cognitive_complexity)]
2583pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
2584where
2585    H: ServerHandler + 'static,
2586{
2587    use rmcp::ServiceExt as _;
2588
2589    tracing::info!("stdio transport: serving on stdin/stdout");
2590    tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
2591
2592    let transport = rmcp::transport::io::stdio();
2593
2594    let service = handler
2595        .serve(transport)
2596        .await
2597        .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
2598
2599    if let Err(e) = service.waiting().await {
2600        tracing::warn!(error = %e, "stdio session ended with error");
2601    }
2602    tracing::info!("stdio session ended");
2603    Ok(())
2604}
2605
2606#[cfg(test)]
2607mod tests {
2608    #![allow(
2609        clippy::unwrap_used,
2610        clippy::expect_used,
2611        clippy::panic,
2612        clippy::indexing_slicing,
2613        clippy::unwrap_in_result,
2614        clippy::print_stdout,
2615        clippy::print_stderr,
2616        deprecated,
2617        reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
2618    )]
2619    use std::sync::Arc;
2620
2621    use axum::{
2622        body::Body,
2623        http::{Request, StatusCode, header},
2624        response::IntoResponse,
2625    };
2626    use http_body_util::BodyExt;
2627    use tower::ServiceExt as _;
2628
2629    use super::*;
2630
2631    // -- McpServerConfig --
2632
2633    #[test]
2634    fn server_config_new_defaults() {
2635        let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
2636        assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
2637        assert_eq!(cfg.name, "test-server");
2638        assert_eq!(cfg.version, "1.0.0");
2639        assert!(cfg.tls_cert_path.is_none());
2640        assert!(cfg.tls_key_path.is_none());
2641        assert!(cfg.auth.is_none());
2642        assert!(cfg.rbac.is_none());
2643        assert!(cfg.allowed_origins.is_empty());
2644        assert!(cfg.tool_rate_limit.is_none());
2645        assert!(cfg.readiness_check.is_none());
2646        assert_eq!(cfg.max_request_body, 1024 * 1024);
2647        assert_eq!(cfg.request_timeout, Duration::from_mins(2));
2648        assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
2649        assert!(!cfg.log_request_headers);
2650    }
2651
2652    #[test]
2653    fn validate_consumes_and_proves() {
2654        // Valid config -> Validated wrapper, original is consumed.
2655        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
2656        let validated = cfg.validate().expect("valid config");
2657        // Deref gives read-only access to inner fields.
2658        assert_eq!(validated.name, "test-server");
2659        // into_inner recovers the raw value.
2660        let raw = validated.into_inner();
2661        assert_eq!(raw.name, "test-server");
2662
2663        // Invalid config (zero max_request_body) -> Err.
2664        let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
2665        bad.max_request_body = 0;
2666        assert!(bad.validate().is_err(), "zero body cap must fail validate");
2667    }
2668
2669    #[test]
2670    fn validate_rejects_zero_max_concurrent_requests() {
2671        let cfg =
2672            McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_max_concurrent_requests(0);
2673        let err = cfg.validate().expect_err("zero concurrency cap must fail");
2674        assert!(
2675            format!("{err}").contains("max_concurrent_requests"),
2676            "error should mention max_concurrent_requests, got: {err}"
2677        );
2678    }
2679
2680    #[test]
2681    fn validate_rejects_zero_max_tracked_keys() {
2682        let rl = crate::auth::RateLimitConfig {
2683            max_tracked_keys: 0,
2684            ..Default::default()
2685        };
2686        let auth_cfg = AuthConfig {
2687            enabled: true,
2688            rate_limit: Some(rl),
2689            ..Default::default()
2690        };
2691        let cfg = McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_auth(auth_cfg);
2692        let err = cfg.validate().expect_err("zero max_tracked_keys must fail");
2693        assert!(
2694            format!("{err}").contains("max_tracked_keys"),
2695            "error should mention max_tracked_keys, got: {err}"
2696        );
2697    }
2698
2699    #[test]
2700    fn derive_allowed_hosts_includes_public_host() {
2701        let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
2702        assert!(
2703            hosts.iter().any(|h| h == "mcp.example.com"),
2704            "public_url host must be allowed"
2705        );
2706    }
2707
2708    #[test]
2709    fn derive_allowed_hosts_includes_bind_authority() {
2710        let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
2711        assert!(
2712            hosts.iter().any(|h| h == "127.0.0.1"),
2713            "bind host must be allowed"
2714        );
2715        assert!(
2716            hosts.iter().any(|h| h == "127.0.0.1:8080"),
2717            "bind authority must be allowed"
2718        );
2719    }
2720
2721    // -- healthz --
2722
2723    #[tokio::test]
2724    async fn healthz_returns_ok_json() {
2725        let resp = healthz().await.into_response();
2726        assert_eq!(resp.status(), StatusCode::OK);
2727        let body = resp.into_body().collect().await.unwrap().to_bytes();
2728        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
2729        assert_eq!(json["status"], "ok");
2730        assert!(
2731            json.get("name").is_none(),
2732            "healthz must not expose server name"
2733        );
2734        assert!(
2735            json.get("version").is_none(),
2736            "healthz must not expose version"
2737        );
2738    }
2739
2740    // -- readyz --
2741
2742    #[tokio::test]
2743    async fn readyz_returns_ok_when_ready() {
2744        let check: ReadinessCheck =
2745            Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
2746        let resp = readyz(check).await.into_response();
2747        assert_eq!(resp.status(), StatusCode::OK);
2748        let body = resp.into_body().collect().await.unwrap().to_bytes();
2749        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
2750        assert_eq!(json["ready"], true);
2751        assert!(
2752            json.get("name").is_none(),
2753            "readyz must not expose server name"
2754        );
2755        assert!(
2756            json.get("version").is_none(),
2757            "readyz must not expose version"
2758        );
2759        assert_eq!(json["db"], "connected");
2760    }
2761
2762    #[tokio::test]
2763    async fn readyz_returns_503_when_not_ready() {
2764        let check: ReadinessCheck =
2765            Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
2766        let resp = readyz(check).await.into_response();
2767        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
2768    }
2769
2770    #[tokio::test]
2771    async fn readyz_returns_503_when_ready_missing() {
2772        let check: ReadinessCheck =
2773            Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
2774        let resp = readyz(check).await.into_response();
2775        // Missing "ready" field defaults to false -> 503
2776        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
2777    }
2778
2779    // -- origin_check_middleware --
2780
2781    /// Build a test router with origin check middleware and a simple handler.
2782    fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
2783        let allowed: Arc<[String]> = Arc::from(origins);
2784        axum::Router::new()
2785            .route("/test", axum::routing::get(|| async { "ok" }))
2786            .layer(axum::middleware::from_fn(move |req, next| {
2787                let a = Arc::clone(&allowed);
2788                origin_check_middleware(a, log_request_headers, req, next)
2789            }))
2790    }
2791
2792    #[tokio::test]
2793    async fn origin_allowed_passes() {
2794        let app = origin_router(vec!["http://localhost:3000".into()], false);
2795        let req = Request::builder()
2796            .uri("/test")
2797            .header(header::ORIGIN, "http://localhost:3000")
2798            .body(Body::empty())
2799            .unwrap();
2800        let resp = app.oneshot(req).await.unwrap();
2801        assert_eq!(resp.status(), StatusCode::OK);
2802    }
2803
2804    #[tokio::test]
2805    async fn origin_rejected_returns_403() {
2806        let app = origin_router(vec!["http://localhost:3000".into()], false);
2807        let req = Request::builder()
2808            .uri("/test")
2809            .header(header::ORIGIN, "http://evil.com")
2810            .body(Body::empty())
2811            .unwrap();
2812        let resp = app.oneshot(req).await.unwrap();
2813        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
2814    }
2815
2816    #[tokio::test]
2817    async fn no_origin_header_passes() {
2818        let app = origin_router(vec!["http://localhost:3000".into()], false);
2819        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2820        let resp = app.oneshot(req).await.unwrap();
2821        assert_eq!(resp.status(), StatusCode::OK);
2822    }
2823
2824    #[tokio::test]
2825    async fn empty_allowlist_rejects_any_origin() {
2826        let app = origin_router(vec![], false);
2827        let req = Request::builder()
2828            .uri("/test")
2829            .header(header::ORIGIN, "http://anything.com")
2830            .body(Body::empty())
2831            .unwrap();
2832        let resp = app.oneshot(req).await.unwrap();
2833        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
2834    }
2835
2836    #[tokio::test]
2837    async fn empty_allowlist_passes_without_origin() {
2838        let app = origin_router(vec![], false);
2839        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2840        let resp = app.oneshot(req).await.unwrap();
2841        assert_eq!(resp.status(), StatusCode::OK);
2842    }
2843
2844    #[test]
2845    fn format_request_headers_redacts_sensitive_values() {
2846        let mut headers = axum::http::HeaderMap::new();
2847        headers.insert("authorization", "Bearer secret-token".parse().unwrap());
2848        headers.insert("cookie", "sid=abc".parse().unwrap());
2849        headers.insert("x-request-id", "req-123".parse().unwrap());
2850
2851        let out = format_request_headers_for_log(&headers);
2852        assert!(out.contains("authorization: [REDACTED]"));
2853        assert!(out.contains("cookie: [REDACTED]"));
2854        assert!(out.contains("x-request-id: req-123"));
2855        assert!(!out.contains("secret-token"));
2856    }
2857
2858    // -- security_headers_middleware --
2859
2860    fn security_router(is_tls: bool) -> axum::Router {
2861        security_router_with(is_tls, SecurityHeadersConfig::default())
2862    }
2863
2864    fn security_router_with(is_tls: bool, cfg: SecurityHeadersConfig) -> axum::Router {
2865        let cfg = Arc::new(cfg);
2866        axum::Router::new()
2867            .route("/test", axum::routing::get(|| async { "ok" }))
2868            .layer(axum::middleware::from_fn(move |req, next| {
2869                let c = Arc::clone(&cfg);
2870                security_headers_middleware(is_tls, c, req, next)
2871            }))
2872    }
2873
2874    #[tokio::test]
2875    async fn security_headers_set_on_response() {
2876        let app = security_router(false);
2877        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2878        let resp = app.oneshot(req).await.unwrap();
2879        assert_eq!(resp.status(), StatusCode::OK);
2880
2881        let h = resp.headers();
2882        assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
2883        assert_eq!(h.get("x-frame-options").unwrap(), "deny");
2884        assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
2885        assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
2886        assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
2887        assert_eq!(
2888            h.get("cross-origin-resource-policy").unwrap(),
2889            "same-origin"
2890        );
2891        assert_eq!(
2892            h.get("cross-origin-embedder-policy").unwrap(),
2893            "require-corp"
2894        );
2895        assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
2896        assert!(
2897            h.get("permissions-policy")
2898                .unwrap()
2899                .to_str()
2900                .unwrap()
2901                .contains("camera=()"),
2902            "permissions-policy must restrict browser features"
2903        );
2904        assert_eq!(
2905            h.get("content-security-policy").unwrap(),
2906            "default-src 'none'; frame-ancestors 'none'"
2907        );
2908        assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
2909        // No HSTS when TLS is off.
2910        assert!(h.get("strict-transport-security").is_none());
2911    }
2912
2913    #[tokio::test]
2914    async fn hsts_set_when_tls_enabled() {
2915        let app = security_router(true);
2916        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2917        let resp = app.oneshot(req).await.unwrap();
2918
2919        let hsts = resp.headers().get("strict-transport-security").unwrap();
2920        assert!(
2921            hsts.to_str().unwrap().contains("max-age=63072000"),
2922            "HSTS must set 2-year max-age"
2923        );
2924    }
2925
2926    // -- SecurityHeadersConfig validation + override semantics --
2927
2928    /// Build a minimal config with a custom SecurityHeadersConfig and
2929    /// drive it through `check()`. Returns the result so individual
2930    /// tests can assert on success or specific error messages.
2931    fn check_with_security_headers(headers: SecurityHeadersConfig) -> Result<(), McpxError> {
2932        let cfg =
2933            McpServerConfig::new("127.0.0.1:8080", "test", "0.0.0").with_security_headers(headers);
2934        cfg.check()
2935    }
2936
2937    #[test]
2938    fn security_headers_config_default_validates() {
2939        check_with_security_headers(SecurityHeadersConfig::default())
2940            .expect("default SecurityHeadersConfig must validate");
2941    }
2942
2943    #[test]
2944    fn security_headers_config_validate_accepts_empty_string() {
2945        // All twelve fields explicitly set to "" -> omit-everything mode.
2946        let h = SecurityHeadersConfig {
2947            x_content_type_options: Some(String::new()),
2948            x_frame_options: Some(String::new()),
2949            cache_control: Some(String::new()),
2950            referrer_policy: Some(String::new()),
2951            cross_origin_opener_policy: Some(String::new()),
2952            cross_origin_resource_policy: Some(String::new()),
2953            cross_origin_embedder_policy: Some(String::new()),
2954            permissions_policy: Some(String::new()),
2955            x_permitted_cross_domain_policies: Some(String::new()),
2956            content_security_policy: Some(String::new()),
2957            x_dns_prefetch_control: Some(String::new()),
2958            strict_transport_security: Some(String::new()),
2959        };
2960        check_with_security_headers(h).expect("Some(\"\") on every field must validate (omit-all)");
2961    }
2962
2963    #[test]
2964    fn security_headers_config_validate_rejects_bad_value() {
2965        // 0x07 (BEL) is not a valid HTTP header value char.
2966        let h = SecurityHeadersConfig {
2967            referrer_policy: Some("\u{0007}".into()),
2968            ..SecurityHeadersConfig::default()
2969        };
2970        let err = check_with_security_headers(h)
2971            .expect_err("control char in referrer_policy must reject");
2972        let msg = err.to_string();
2973        assert!(
2974            msg.contains("referrer_policy"),
2975            "error must name the offending field, got: {msg}"
2976        );
2977    }
2978
2979    #[test]
2980    fn security_headers_config_validate_rejects_hsts_preload() {
2981        let h = SecurityHeadersConfig {
2982            strict_transport_security: Some("max-age=63072000; includeSubDomains; preload".into()),
2983            ..SecurityHeadersConfig::default()
2984        };
2985        let err = check_with_security_headers(h).expect_err("HSTS with preload must reject");
2986        let msg = err.to_string();
2987        assert!(
2988            msg.contains("strict_transport_security"),
2989            "error must name the field, got: {msg}"
2990        );
2991        assert!(
2992            msg.to_lowercase().contains("preload"),
2993            "error must mention `preload`, got: {msg}"
2994        );
2995    }
2996
2997    #[test]
2998    fn security_headers_config_validate_rejects_hsts_preload_uppercase() {
2999        // Case-insensitive match.
3000        let h = SecurityHeadersConfig {
3001            strict_transport_security: Some("max-age=600; PRELOAD".into()),
3002            ..SecurityHeadersConfig::default()
3003        };
3004        check_with_security_headers(h).expect_err("HSTS preload check must be case-insensitive");
3005    }
3006
3007    #[tokio::test]
3008    async fn security_headers_override_honored() {
3009        // Override X-Frame-Options to SAMEORIGIN.
3010        let h = SecurityHeadersConfig {
3011            x_frame_options: Some("SAMEORIGIN".into()),
3012            ..SecurityHeadersConfig::default()
3013        };
3014        let app = security_router_with(false, h);
3015        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3016        let resp = app.oneshot(req).await.unwrap();
3017        assert_eq!(resp.status(), StatusCode::OK);
3018
3019        let xfo = resp.headers().get("x-frame-options").unwrap();
3020        assert_eq!(xfo, "SAMEORIGIN");
3021    }
3022
3023    #[tokio::test]
3024    async fn security_headers_empty_string_omits() {
3025        // Empty string on referrer-policy -> header absent.
3026        let h = SecurityHeadersConfig {
3027            referrer_policy: Some(String::new()),
3028            ..SecurityHeadersConfig::default()
3029        };
3030        let app = security_router_with(false, h);
3031        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3032        let resp = app.oneshot(req).await.unwrap();
3033        assert_eq!(resp.status(), StatusCode::OK);
3034
3035        assert!(
3036            resp.headers().get("referrer-policy").is_none(),
3037            "Some(\"\") must omit the header"
3038        );
3039        // Other defaults should still be present.
3040        assert_eq!(
3041            resp.headers().get("x-content-type-options").unwrap(),
3042            "nosniff"
3043        );
3044    }
3045
3046    #[tokio::test]
3047    async fn security_headers_hsts_only_when_tls() {
3048        // HSTS override is irrelevant when TLS is off.
3049        let h = SecurityHeadersConfig {
3050            strict_transport_security: Some("max-age=600".into()),
3051            ..SecurityHeadersConfig::default()
3052        };
3053        let app = security_router_with(false, h);
3054        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3055        let resp = app.oneshot(req).await.unwrap();
3056        assert!(
3057            resp.headers().get("strict-transport-security").is_none(),
3058            "HSTS must remain absent on plaintext deployments even with override"
3059        );
3060    }
3061
3062    // -- oauth_token_cache_headers_middleware --
3063
3064    #[cfg(feature = "oauth")]
3065    #[tokio::test]
3066    async fn oauth_token_cache_headers_set_pragma_and_vary() {
3067        let app = axum::Router::new()
3068            .route("/token", axum::routing::post(|| async { "{}" }))
3069            .layer(axum::middleware::from_fn(
3070                oauth_token_cache_headers_middleware,
3071            ));
3072        let req = Request::builder()
3073            .method("POST")
3074            .uri("/token")
3075            .body(Body::from("{}"))
3076            .unwrap();
3077        let resp = app.oneshot(req).await.unwrap();
3078        assert_eq!(resp.status(), StatusCode::OK);
3079
3080        let h = resp.headers();
3081        assert_eq!(
3082            h.get("pragma").unwrap(),
3083            "no-cache",
3084            "RFC 6749 §5.1: token responses must set Pragma: no-cache"
3085        );
3086        let vary_values: Vec<String> = h
3087            .get_all("vary")
3088            .iter()
3089            .filter_map(|v| v.to_str().ok().map(str::to_owned))
3090            .collect();
3091        assert!(
3092            vary_values
3093                .iter()
3094                .any(|v| v.eq_ignore_ascii_case("Authorization")),
3095            "RFC 6750 §5.4: Vary must include Authorization, got {vary_values:?}"
3096        );
3097    }
3098
3099    #[cfg(feature = "oauth")]
3100    #[tokio::test]
3101    async fn oauth_token_cache_headers_preserve_existing_vary() {
3102        // Simulates a handler/layer that already set `Vary: Accept-Encoding`
3103        // (e.g. compression). Our middleware must APPEND, not REPLACE.
3104        let app = axum::Router::new()
3105            .route(
3106                "/token",
3107                axum::routing::post(|| async {
3108                    axum::response::Response::builder()
3109                        .header("vary", "Accept-Encoding")
3110                        .body(axum::body::Body::from("{}"))
3111                        .unwrap()
3112                }),
3113            )
3114            .layer(axum::middleware::from_fn(
3115                oauth_token_cache_headers_middleware,
3116            ));
3117        let req = Request::builder()
3118            .method("POST")
3119            .uri("/token")
3120            .body(Body::empty())
3121            .unwrap();
3122        let resp = app.oneshot(req).await.unwrap();
3123
3124        let vary: Vec<String> = resp
3125            .headers()
3126            .get_all("vary")
3127            .iter()
3128            .filter_map(|v| v.to_str().ok().map(str::to_owned))
3129            .collect();
3130        assert!(
3131            vary.iter().any(|v| v.contains("Accept-Encoding")),
3132            "must preserve pre-existing Vary value, got {vary:?}"
3133        );
3134        assert!(
3135            vary.iter().any(|v| v.contains("Authorization")),
3136            "must append Authorization to Vary, got {vary:?}"
3137        );
3138    }
3139
3140    // -- version endpoint --
3141
3142    #[test]
3143    fn version_payload_contains_expected_fields() {
3144        let v = version_payload("my-server", "1.2.3");
3145        assert_eq!(v["name"], "my-server");
3146        assert_eq!(v["version"], "1.2.3");
3147        assert!(v["build_git_sha"].is_string());
3148        assert!(v["build_timestamp"].is_string());
3149        assert!(v["rust_version"].is_string());
3150        assert!(v["mcpx_version"].is_string());
3151    }
3152
3153    // -- concurrency limit layer --
3154
3155    #[tokio::test]
3156    async fn concurrency_limit_layer_composes_and_serves() {
3157        // We only assert the layer stack compiles and a single request
3158        // below the cap still succeeds. True back-pressure behaviour
3159        // requires a live HTTP server and is covered by integration tests.
3160        let app = axum::Router::new()
3161            .route("/ok", axum::routing::get(|| async { "ok" }))
3162            .layer(
3163                tower::ServiceBuilder::new()
3164                    .layer(axum::error_handling::HandleErrorLayer::new(
3165                        |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
3166                    ))
3167                    .layer(tower::load_shed::LoadShedLayer::new())
3168                    .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
3169            );
3170        let resp = app
3171            .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
3172            .await
3173            .unwrap();
3174        assert_eq!(resp.status(), StatusCode::OK);
3175    }
3176
3177    // -- compression layer --
3178
3179    #[tokio::test]
3180    async fn compression_layer_gzip_encodes_response() {
3181        use tower_http::compression::Predicate as _;
3182
3183        let big_body = "a".repeat(4096);
3184        let app = axum::Router::new()
3185            .route(
3186                "/big",
3187                axum::routing::get(move || {
3188                    let body = big_body.clone();
3189                    async move { body }
3190                }),
3191            )
3192            .layer(
3193                tower_http::compression::CompressionLayer::new()
3194                    .gzip(true)
3195                    .br(true)
3196                    .compress_when(
3197                        tower_http::compression::DefaultPredicate::new()
3198                            .and(tower_http::compression::predicate::SizeAbove::new(1024)),
3199                    ),
3200            );
3201
3202        let req = Request::builder()
3203            .uri("/big")
3204            .header(header::ACCEPT_ENCODING, "gzip")
3205            .body(Body::empty())
3206            .unwrap();
3207        let resp = app.oneshot(req).await.unwrap();
3208        assert_eq!(resp.status(), StatusCode::OK);
3209        assert_eq!(
3210            resp.headers().get(header::CONTENT_ENCODING).unwrap(),
3211            "gzip"
3212        );
3213    }
3214}