Skip to main content

rmcp_server_kit/
transport.rs

1use std::{
2    future::Future,
3    net::SocketAddr,
4    path::{Path, PathBuf},
5    pin::Pin,
6    sync::Arc,
7    time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{body::Body, extract::Request, middleware::Next, response::IntoResponse};
12use rmcp::{
13    ServerHandler,
14    transport::streamable_http_server::{
15        StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
16    },
17};
18use rustls::RootCertStore;
19use tokio::net::TcpListener;
20use tokio_util::sync::CancellationToken;
21
22use crate::{
23    auth::{
24        AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
25        build_rate_limiter, extract_mtls_identity,
26    },
27    error::McpxError,
28    mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
29    rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
30};
31
32/// Map an internal `anyhow::Error` chain into a public [`McpxError::Startup`]
33/// at the public API boundary, flattening the chain via the alternate
34/// formatter so callers see the full causal path.
35#[allow(
36    clippy::needless_pass_by_value,
37    reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
38)]
39fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
40    McpxError::Startup(format!("{e:#}"))
41}
42
43/// Map a `std::io::Error` produced during server startup into a public
44/// [`McpxError::Startup`]. We deliberately do not use the [`McpxError::Io`]
45/// `From` impl here because startup-phase IO errors (bind, listener) are
46/// semantically distinct from request-time IO errors and should surface
47/// the originating operation in the message.
48#[allow(
49    clippy::needless_pass_by_value,
50    reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
51)]
52fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
53    McpxError::Startup(format!("{op}: {e}"))
54}
55
56/// Async readiness check callback for the `/readyz` endpoint.
57///
58/// Returns a JSON object with at least a `"ready"` boolean.
59/// When `ready` is false, the endpoint returns HTTP 503.
60pub type ReadinessCheck =
61    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
62
/// Overrides for the individual OWASP security headers that the global
/// response middleware attaches to responses.
///
/// Every field is tri-state:
///
/// | Value         | Behaviour                                                |
/// |---------------|----------------------------------------------------------|
/// | `None`        | Keep the built-in default for that header.               |
/// | `Some("")`    | Do **not** send the header at all.                       |
/// | `Some(value)` | Send `header: value`. Checked when the config is loaded. |
///
/// Non-empty values are checked with
/// [`axum::http::HeaderValue::from_str`] from within
/// [`McpServerConfig::validate`], so a bad value stops the server
/// before it accepts any traffic.
///
/// `Strict-Transport-Security` carries one extra rule: any value that
/// contains `preload` (matched case-insensitively) is rejected. Opting
/// into the HSTS preload list is reserved for a future explicit builder
/// method and cannot be slipped in through this override.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct SecurityHeadersConfig {
    /// `X-Content-Type-Options` override. Built-in default: `nosniff`.
    pub x_content_type_options: Option<String>,
    /// `X-Frame-Options` override. Built-in default: `deny`.
    pub x_frame_options: Option<String>,
    /// `Cache-Control` override. Built-in default: `no-store, max-age=0`.
    pub cache_control: Option<String>,
    /// `Referrer-Policy` override. Built-in default: `no-referrer`.
    pub referrer_policy: Option<String>,
    /// `Cross-Origin-Opener-Policy` override. Built-in default:
    /// `same-origin`.
    pub cross_origin_opener_policy: Option<String>,
    /// `Cross-Origin-Resource-Policy` override. Built-in default:
    /// `same-origin`.
    pub cross_origin_resource_policy: Option<String>,
    /// `Cross-Origin-Embedder-Policy` override. Built-in default:
    /// `require-corp`.
    pub cross_origin_embedder_policy: Option<String>,
    /// `Permissions-Policy` override. Built-in default:
    /// `accelerometer=(), camera=(), geolocation=(), microphone=()`.
    pub permissions_policy: Option<String>,
    /// `X-Permitted-Cross-Domain-Policies` override. Built-in default:
    /// `none`.
    pub x_permitted_cross_domain_policies: Option<String>,
    /// `Content-Security-Policy` override. Built-in default:
    /// `default-src 'none'; frame-ancestors 'none'`.
    pub content_security_policy: Option<String>,
    /// `X-DNS-Prefetch-Control` override. Built-in default: `off`.
    pub x_dns_prefetch_control: Option<String>,
    /// `Strict-Transport-Security` override. Built-in default (TLS
    /// only): `max-age=63072000; includeSubDomains`. The header is sent
    /// only while TLS is active, so the override has no effect on
    /// plaintext deployments. Values containing `preload` (in any case)
    /// fail validation.
    pub strict_transport_security: Option<String>,
}
116
117/// Configuration for the MCP server.
118#[allow(
119    missing_debug_implementations,
120    reason = "contains callback/trait objects that don't impl Debug"
121)]
122#[allow(
123    clippy::struct_excessive_bools,
124    reason = "server configuration naturally has many boolean feature flags"
125)]
126#[non_exhaustive]
127pub struct McpServerConfig {
128    /// Socket address the MCP HTTP server binds to.
129    #[deprecated(
130        since = "0.13.0",
131        note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in a future major release"
132    )]
133    pub bind_addr: String,
134    /// Server name advertised via MCP `initialize`.
135    #[deprecated(
136        since = "0.13.0",
137        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
138    )]
139    pub name: String,
140    /// Server version advertised via MCP `initialize`.
141    #[deprecated(
142        since = "0.13.0",
143        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
144    )]
145    pub version: String,
146    /// Path to the TLS certificate (PEM). Required for TLS/mTLS.
147    #[deprecated(
148        since = "0.13.0",
149        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
150    )]
151    pub tls_cert_path: Option<PathBuf>,
152    /// Path to the TLS private key (PEM). Required for TLS/mTLS.
153    #[deprecated(
154        since = "0.13.0",
155        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
156    )]
157    pub tls_key_path: Option<PathBuf>,
158    /// Optional authentication config. When `Some` and `enabled`, auth
159    /// is enforced on `/mcp`. `/healthz` is always open.
160    #[deprecated(
161        since = "0.13.0",
162        note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in a future major release"
163    )]
164    pub auth: Option<AuthConfig>,
165    /// Optional RBAC policy. When present and enabled, tool calls are
166    /// checked against the policy after authentication.
167    #[deprecated(
168        since = "0.13.0",
169        note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in a future major release"
170    )]
171    pub rbac: Option<Arc<RbacPolicy>>,
172    /// Allowed Origin values for DNS rebinding protection (MCP spec MUST).
173    /// When empty and `public_url` is set, the origin is auto-derived from
174    /// the public URL. When both are empty, only requests with no Origin
175    /// header are accepted.
176    /// Example entries: `"http://localhost:3000"`, `"https://myapp.example.com"`.
177    #[deprecated(
178        since = "0.13.0",
179        note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in a future major release"
180    )]
181    pub allowed_origins: Vec<String>,
182    /// Maximum tool invocations per source IP per minute.
183    /// When set, enforced on every `tools/call` request.
184    #[deprecated(
185        since = "0.13.0",
186        note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in a future major release"
187    )]
188    pub tool_rate_limit: Option<u32>,
189    /// Optional readiness probe for `/readyz`.
190    /// When `None`, `/readyz` mirrors `/healthz` (always OK).
191    #[deprecated(
192        since = "0.13.0",
193        note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in a future major release"
194    )]
195    pub readiness_check: Option<ReadinessCheck>,
196    /// Maximum request body size in bytes. Default: 1 MiB.
197    /// Protects against oversized payloads causing OOM.
198    #[deprecated(
199        since = "0.13.0",
200        note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in a future major release"
201    )]
202    pub max_request_body: usize,
203    /// Request processing timeout. Default: 120s.
204    /// Requests exceeding this duration receive 408 Request Timeout.
205    #[deprecated(
206        since = "0.13.0",
207        note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in a future major release"
208    )]
209    pub request_timeout: Duration,
210    /// Graceful shutdown timeout. Default: 30s.
211    /// After the shutdown signal, in-flight requests have this long to finish.
212    #[deprecated(
213        since = "0.13.0",
214        note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in a future major release"
215    )]
216    pub shutdown_timeout: Duration,
217    /// Idle timeout for MCP sessions. Sessions with no activity for this
218    /// duration are closed automatically. Default: 20 minutes.
219    #[deprecated(
220        since = "0.13.0",
221        note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in a future major release"
222    )]
223    pub session_idle_timeout: Duration,
224    /// Interval for SSE keep-alive pings. Prevents proxies and load
225    /// balancers from killing idle connections. Default: 15 seconds.
226    #[deprecated(
227        since = "0.13.0",
228        note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in a future major release"
229    )]
230    pub sse_keep_alive: Duration,
231    /// Callback invoked once the server is built, delivering a
232    /// [`ReloadHandle`] for hot-reloading auth keys and RBAC policy
233    /// at runtime (e.g. on SIGHUP). Only useful when auth/RBAC is enabled.
234    #[deprecated(
235        since = "0.13.0",
236        note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in a future major release"
237    )]
238    pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
239    /// Additional application-specific routes merged into the top-level
240    /// router.  These routes **bypass** the MCP auth and RBAC middleware,
241    /// so the application is responsible for its own auth on them.
242    #[deprecated(
243        since = "0.13.0",
244        note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in a future major release"
245    )]
246    pub extra_router: Option<axum::Router>,
247    /// Externally reachable base URL (e.g. `https://mcp.example.com`).
248    /// When set, OAuth metadata endpoints advertise this URL instead of
249    /// the listen address. Required when binding `0.0.0.0` behind a
250    /// reverse proxy or inside a container.
251    #[deprecated(
252        since = "0.13.0",
253        note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in a future major release"
254    )]
255    pub public_url: Option<String>,
256    /// Log inbound HTTP request headers at DEBUG level.
257    /// Sensitive values remain redacted.
258    #[deprecated(
259        since = "0.13.0",
260        note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in a future major release"
261    )]
262    pub log_request_headers: bool,
263    /// Enable gzip/br response compression on MCP responses.
264    /// Defaults to `false` to preserve existing behaviour.
265    #[deprecated(
266        since = "0.13.0",
267        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
268    )]
269    pub compression_enabled: bool,
270    /// Minimum response body size (in bytes) before compression kicks in.
271    /// Only used when `compression_enabled` is true. Default: 1024.
272    #[deprecated(
273        since = "0.13.0",
274        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
275    )]
276    pub compression_min_size: u16,
277    /// Global cap on in-flight HTTP requests across the whole server.
278    /// When `Some`, requests over the cap receive 503 Service Unavailable
279    /// via `tower::load_shed`. Default: `None` (unlimited).
280    #[deprecated(
281        since = "0.13.0",
282        note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in a future major release"
283    )]
284    pub max_concurrent_requests: Option<usize>,
285    /// Enable `/admin/*` diagnostic endpoints. Requires `auth` to be
286    /// configured and `enabled`. Default: `false`.
287    #[deprecated(
288        since = "0.13.0",
289        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
290    )]
291    pub admin_enabled: bool,
292    /// RBAC role required to access admin endpoints. Default: `"admin"`.
293    #[deprecated(
294        since = "0.13.0",
295        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
296    )]
297    pub admin_role: String,
298    /// Enable Prometheus metrics endpoint on a separate listener.
299    /// Requires the `metrics` crate feature.
300    #[cfg(feature = "metrics")]
301    #[deprecated(
302        since = "0.13.0",
303        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
304    )]
305    pub metrics_enabled: bool,
306    /// Bind address for the Prometheus metrics listener. Default: `127.0.0.1:9090`.
307    #[cfg(feature = "metrics")]
308    #[deprecated(
309        since = "0.13.0",
310        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
311    )]
312    pub metrics_bind: String,
313    /// Per-header overrides for the OWASP security headers emitted by
314    /// the global response middleware. See [`SecurityHeadersConfig`]
315    /// for the three-state semantic and validation rules.
316    #[deprecated(
317        since = "1.5.0",
318        note = "use McpServerConfig::with_security_headers(); direct field access will become pub(crate) in a future major release"
319    )]
320    pub security_headers: SecurityHeadersConfig,
321}
322
/// Proof wrapper: holding a `Validated<T>` means the wrapped value has
/// passed its validation contract.
///
/// A `Validated<McpServerConfig>` can only be produced by
/// [`McpServerConfig::validate`]. [`serve`] and [`serve_with_listener`]
/// take that type, which turns "validate first" into a compile-time
/// requirement. Because the inner field is private, downstream code
/// has no way to build the wrapper by hand and skip validation.
///
/// Borrow the value read-only with [`Validated::as_inner`]. To change
/// it, take the raw value back out with [`Validated::into_inner`] and
/// validate again.
///
/// # Example
///
/// ```no_run
/// use rmcp_server_kit::transport::{McpServerConfig, Validated, serve};
/// use rmcp::handler::server::ServerHandler;
/// use rmcp::model::{ServerCapabilities, ServerInfo};
///
/// #[derive(Clone)]
/// struct H;
/// impl ServerHandler for H {
///     fn get_info(&self) -> ServerInfo {
///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
///     }
/// }
///
/// # async fn example() -> rmcp_server_kit::Result<()> {
/// let config: Validated<McpServerConfig> =
///     McpServerConfig::new("127.0.0.1:8080", "my-server", "0.1.0").validate()?;
/// serve(config, || H).await
/// # }
/// ```
///
/// Leaving out `.validate()?` does not compile:
///
/// ```compile_fail
/// use rmcp_server_kit::transport::{McpServerConfig, serve};
/// use rmcp::handler::server::ServerHandler;
/// use rmcp::model::{ServerCapabilities, ServerInfo};
///
/// #[derive(Clone)]
/// struct H;
/// impl ServerHandler for H {
///     fn get_info(&self) -> ServerInfo {
///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
///     }
/// }
///
/// # async fn example() -> rmcp_server_kit::Result<()> {
/// let config = McpServerConfig::new("127.0.0.1:8080", "my-server", "0.1.0");
/// // Missing `.validate()?` -> mismatched types: expected
/// // `Validated<McpServerConfig>`, found `McpServerConfig`.
/// serve(config, || H).await
/// # }
/// ```
#[allow(
    missing_debug_implementations,
    reason = "wraps T which may not implement Debug; manual impl below avoids leaking inner contents into logs"
)]
pub struct Validated<T>(T);
385
386impl<T> std::fmt::Debug for Validated<T> {
387    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
388        f.debug_struct("Validated").finish_non_exhaustive()
389    }
390}
391
392impl<T> Validated<T> {
393    /// Borrow the inner value.
394    #[must_use]
395    pub fn as_inner(&self) -> &T {
396        &self.0
397    }
398
399    /// Recover the raw value, discarding the validation proof.
400    ///
401    /// Re-validate before re-using the value with [`serve`] or
402    /// [`serve_with_listener`].
403    #[must_use]
404    pub fn into_inner(self) -> T {
405        self.0
406    }
407}
408
409#[allow(
410    deprecated,
411    reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
412)]
413impl McpServerConfig {
414    /// Create a new server configuration with the given bind address,
415    /// server name, and version. All other fields use safe defaults.
416    ///
417    /// Use the chainable `with_*` / `enable_*` builder methods to
418    /// customize. Call [`McpServerConfig::validate`] to obtain a
419    /// [`Validated<McpServerConfig>`] proof token, which is required by
420    /// [`serve`] and [`serve_with_listener`].
421    #[must_use]
422    pub fn new(
423        bind_addr: impl Into<String>,
424        name: impl Into<String>,
425        version: impl Into<String>,
426    ) -> Self {
427        Self {
428            bind_addr: bind_addr.into(),
429            name: name.into(),
430            version: version.into(),
431            tls_cert_path: None,
432            tls_key_path: None,
433            auth: None,
434            rbac: None,
435            allowed_origins: Vec::new(),
436            tool_rate_limit: None,
437            readiness_check: None,
438            max_request_body: 1024 * 1024,
439            request_timeout: Duration::from_mins(2),
440            shutdown_timeout: Duration::from_secs(30),
441            session_idle_timeout: Duration::from_mins(20),
442            sse_keep_alive: Duration::from_secs(15),
443            on_reload_ready: None,
444            extra_router: None,
445            public_url: None,
446            log_request_headers: false,
447            compression_enabled: false,
448            compression_min_size: 1024,
449            max_concurrent_requests: None,
450            admin_enabled: false,
451            admin_role: "admin".to_owned(),
452            #[cfg(feature = "metrics")]
453            metrics_enabled: false,
454            #[cfg(feature = "metrics")]
455            metrics_bind: "127.0.0.1:9090".into(),
456            security_headers: SecurityHeadersConfig::default(),
457        }
458    }
459
460    // ---------------------------------------------------------------
461    // Builder methods (fluent, consume + return self).
462    //
463    // Each method is `#[must_use]` because dropping the returned
464    // `McpServerConfig` discards the configuration change.
465    // ---------------------------------------------------------------
466
467    /// Attach an authentication configuration. Required for
468    /// [`enable_admin`](Self::enable_admin) and any non-public deployment.
469    #[must_use]
470    pub fn with_auth(mut self, auth: AuthConfig) -> Self {
471        self.auth = Some(auth);
472        self
473    }
474
475    /// Override one or more of the OWASP security headers emitted on
476    /// every response. See [`SecurityHeadersConfig`] for the three-state
477    /// semantic (`None` = default, `Some("")` = omit, `Some(v)` =
478    /// override). Values are validated by [`Self::validate`].
479    #[must_use]
480    pub fn with_security_headers(mut self, headers: SecurityHeadersConfig) -> Self {
481        self.security_headers = headers;
482        self
483    }
484
485    /// Override the bind address (e.g. `127.0.0.1:8080`). Useful when the
486    /// final port is only known after pre-binding an ephemeral listener
487    /// (tests, dynamic-port deployments).
488    #[must_use]
489    pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
490        self.bind_addr = addr.into();
491        self
492    }
493
494    /// Attach an RBAC policy. Tool calls are checked against the policy
495    /// after authentication.
496    #[must_use]
497    pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
498        self.rbac = Some(rbac);
499        self
500    }
501
502    /// Configure TLS by providing the certificate and private key paths
503    /// (PEM). Both must be readable at startup. Without this call, the
504    /// server runs plain HTTP.
505    #[must_use]
506    pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
507        self.tls_cert_path = Some(cert_path.into());
508        self.tls_key_path = Some(key_path.into());
509        self
510    }
511
512    /// Set the externally reachable base URL (e.g. `https://mcp.example.com`).
513    /// Required when binding `0.0.0.0` behind a reverse proxy or inside
514    /// a container so OAuth metadata and auto-derived origins resolve correctly.
515    #[must_use]
516    pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
517        self.public_url = Some(url.into());
518        self
519    }
520
521    /// Replace the allowed Origin allow-list (DNS-rebinding protection).
522    /// When empty and [`with_public_url`](Self::with_public_url) is set,
523    /// the origin is auto-derived.
524    #[must_use]
525    pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
526    where
527        I: IntoIterator<Item = S>,
528        S: Into<String>,
529    {
530        self.allowed_origins = origins.into_iter().map(Into::into).collect();
531        self
532    }
533
534    /// Merge an additional axum router at the top level. Routes added
535    /// here **bypass** rmcp-server-kit auth and RBAC; the application is responsible
536    /// for its own protection.
537    #[must_use]
538    pub fn with_extra_router(mut self, router: axum::Router) -> Self {
539        self.extra_router = Some(router);
540        self
541    }
542
543    /// Install an async readiness probe for `/readyz`. Without this call,
544    /// `/readyz` mirrors `/healthz` (always 200 OK).
545    #[must_use]
546    pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
547        self.readiness_check = Some(check);
548        self
549    }
550
551    /// Override the maximum request body (bytes). Must be `> 0`.
552    /// Default: 1 MiB.
553    #[must_use]
554    pub fn with_max_request_body(mut self, bytes: usize) -> Self {
555        self.max_request_body = bytes;
556        self
557    }
558
559    /// Override the per-request processing timeout. Default: 2 minutes.
560    #[must_use]
561    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
562        self.request_timeout = timeout;
563        self
564    }
565
566    /// Override the graceful shutdown grace period. Default: 30 seconds.
567    #[must_use]
568    pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
569        self.shutdown_timeout = timeout;
570        self
571    }
572
573    /// Override the MCP session idle timeout. Default: 20 minutes.
574    #[must_use]
575    pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
576        self.session_idle_timeout = timeout;
577        self
578    }
579
580    /// Override the SSE keep-alive interval. Default: 15 seconds.
581    #[must_use]
582    pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
583        self.sse_keep_alive = interval;
584        self
585    }
586
587    /// Cap the global number of in-flight HTTP requests via
588    /// `tower::load_shed`. Excess requests receive 503 Service Unavailable.
589    /// Default: unlimited.
590    #[must_use]
591    pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
592        self.max_concurrent_requests = Some(limit);
593        self
594    }
595
596    /// Cap tool invocations per source IP per minute. Enforced on every
597    /// `tools/call` request.
598    #[must_use]
599    pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
600        self.tool_rate_limit = Some(per_minute);
601        self
602    }
603
604    /// Register a callback that receives the [`ReloadHandle`] after the
605    /// server is built. Use it to wire SIGHUP-style hot reloads of API
606    /// keys and RBAC policy.
607    #[must_use]
608    pub fn with_reload_callback<F>(mut self, callback: F) -> Self
609    where
610        F: FnOnce(ReloadHandle) + Send + 'static,
611    {
612        self.on_reload_ready = Some(Box::new(callback));
613        self
614    }
615
616    /// Enable gzip/brotli response compression on MCP responses.
617    /// `min_size` is the smallest body size (bytes) eligible for
618    /// compression. Default min size: 1024.
619    #[must_use]
620    pub fn enable_compression(mut self, min_size: u16) -> Self {
621        self.compression_enabled = true;
622        self.compression_min_size = min_size;
623        self
624    }
625
626    /// Enable `/admin/*` diagnostic endpoints. Requires
627    /// [`with_auth`](Self::with_auth) to be set and enabled; otherwise
628    /// [`validate`](Self::validate) returns an error. `role` is the RBAC
629    /// role gate (default: `"admin"`).
630    #[must_use]
631    pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
632        self.admin_enabled = true;
633        self.admin_role = role.into();
634        self
635    }
636
637    /// Log inbound HTTP request headers at DEBUG level. Sensitive
638    /// values remain redacted by the logging layer.
639    #[must_use]
640    pub fn enable_request_header_logging(mut self) -> Self {
641        self.log_request_headers = true;
642        self
643    }
644
645    /// Enable the Prometheus metrics listener on `bind` (e.g.
646    /// `127.0.0.1:9090`). Requires the `metrics` crate feature.
647    #[cfg(feature = "metrics")]
648    #[must_use]
649    pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
650        self.metrics_enabled = true;
651        self.metrics_bind = bind.into();
652        self
653    }
654
655    /// Validate the configuration and consume `self`, returning a
656    /// [`Validated<McpServerConfig>`] proof token required by [`serve`]
657    /// and [`serve_with_listener`]. This is the only way to construct
658    /// `Validated<McpServerConfig>`, so the type system guarantees
659    /// validation has run before the server starts.
660    ///
661    /// Checks:
662    ///
663    /// 1. `admin_enabled` requires `auth` to be configured and enabled.
664    /// 2. `tls_cert_path` and `tls_key_path` must both be set or both
665    ///    be unset.
666    /// 3. `bind_addr` must parse as a [`SocketAddr`].
667    /// 4. `public_url`, when set, must start with `http://` or `https://`.
668    /// 5. Each entry in `allowed_origins` must start with `http://` or
669    ///    `https://`.
670    /// 6. `max_request_body` must be greater than zero.
671    /// 7. When the `oauth` feature is enabled and an [`OAuthConfig`] is
672    ///    present, all OAuth URL fields (`jwks_uri`, `proxy.authorize_url`,
673    ///    `proxy.token_url`, `proxy.introspection_url`,
674    ///    `proxy.revocation_url`, `token_exchange.token_url`) must parse
675    ///    and use the `https` scheme. Set
676    ///    [`OAuthConfig::allow_http_oauth_urls`] to permit `http://`
677    ///    targets (strongly discouraged in production - see the field-level
678    ///    docs for the threat model).
679    ///
680    /// [`OAuthConfig`]: crate::oauth::OAuthConfig
681    /// [`OAuthConfig::allow_http_oauth_urls`]: crate::oauth::OAuthConfig::allow_http_oauth_urls
682    ///
683    /// # Errors
684    ///
685    /// Returns [`McpxError::Config`] with a human-readable message on
686    /// the first validation failure.
687    pub fn validate(self) -> Result<Validated<Self>, McpxError> {
688        self.check()?;
689        Ok(Validated(self))
690    }
691
692    /// Run the validation checks without consuming `self`. Used by
693    /// internal call sites (e.g. tests) that need to inspect a config
694    /// without taking ownership.
695    fn check(&self) -> Result<(), McpxError> {
696        // 1. admin <-> auth dependency. Mirrors the runtime check in
697        //    `build_app_router`: admin endpoints require an auth state,
698        //    which is built only when `auth` is `Some` *and* `enabled`.
699        if self.admin_enabled {
700            let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
701            if !auth_enabled {
702                return Err(McpxError::Config(
703                    "admin_enabled=true requires auth to be configured and enabled".into(),
704                ));
705            }
706        }
707
708        // 2. TLS cert / key must be paired
709        match (&self.tls_cert_path, &self.tls_key_path) {
710            (Some(_), None) => {
711                return Err(McpxError::Config(
712                    "tls_cert_path is set but tls_key_path is missing".into(),
713                ));
714            }
715            (None, Some(_)) => {
716                return Err(McpxError::Config(
717                    "tls_key_path is set but tls_cert_path is missing".into(),
718                ));
719            }
720            _ => {}
721        }
722
723        // 3. bind_addr parses
724        if self.bind_addr.parse::<SocketAddr>().is_err() {
725            return Err(McpxError::Config(format!(
726                "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
727                self.bind_addr
728            )));
729        }
730
731        // 4. public_url scheme
732        if let Some(ref url) = self.public_url
733            && !(url.starts_with("http://") || url.starts_with("https://"))
734        {
735            return Err(McpxError::Config(format!(
736                "public_url {url:?} must start with http:// or https://"
737            )));
738        }
739
740        // 5. allowed_origins scheme
741        for origin in &self.allowed_origins {
742            if !(origin.starts_with("http://") || origin.starts_with("https://")) {
743                return Err(McpxError::Config(format!(
744                    "allowed_origins entry {origin:?} must start with http:// or https://"
745                )));
746            }
747        }
748
749        // 6. max_request_body > 0
750        if self.max_request_body == 0 {
751            return Err(McpxError::Config(
752                "max_request_body must be greater than zero".into(),
753            ));
754        }
755
756        // 7. OAuth URL fields enforce HTTPS (unless `allow_http_oauth_urls`)
757        #[cfg(feature = "oauth")]
758        if let Some(auth_cfg) = &self.auth
759            && let Some(oauth_cfg) = &auth_cfg.oauth
760        {
761            oauth_cfg.validate()?;
762        }
763
764        // 8. Security-header overrides parse as valid HTTP header values,
765        //    and HSTS does not smuggle in a `preload` directive.
766        validate_security_headers(&self.security_headers)?;
767
768        // 9. max_concurrent_requests must be > 0 when set. Zero would
769        //    deadlock the global concurrency limiter and reject every
770        //    request. Mirrors the TOML-side check in `src/config.rs`.
771        if let Some(0) = self.max_concurrent_requests {
772            return Err(McpxError::Config(
773                "max_concurrent_requests must be greater than zero when set".into(),
774            ));
775        }
776
777        // 10. Auth rate-limit `max_tracked_keys` must be > 0. A zero cap
778        //     would force `BoundedKeyedLimiter` to evict on every insert
779        //     and effectively disable rate limiting.
780        if let Some(auth_cfg) = &self.auth
781            && let Some(rl) = &auth_cfg.rate_limit
782            && rl.max_tracked_keys == 0
783        {
784            return Err(McpxError::Config(
785                "auth.rate_limit.max_tracked_keys must be greater than zero".into(),
786            ));
787        }
788
789        Ok(())
790    }
791}
792
/// Handle for hot-reloading server configuration without restart.
///
/// Obtained via [`McpServerConfig::on_reload_ready`].
/// All swap operations are lock-free and wait-free -- in-flight requests
/// finish with the old values while new requests see the update immediately.
///
/// Each field is optional: a reload targeting a subsystem that is absent
/// from the handle is either a silent no-op (`reload_auth_keys`,
/// `reload_rbac`) or an error (`refresh_crls`).
#[allow(
    missing_debug_implementations,
    reason = "contains Arc<AuthState> with non-Debug fields"
)]
pub struct ReloadHandle {
    /// Shared auth state whose API key list `reload_auth_keys` replaces.
    auth: Option<Arc<AuthState>>,
    /// Swappable RBAC policy that `reload_rbac` stores into.
    rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
    /// mTLS CRL set that `refresh_crls` force-refreshes.
    crl_set: Option<Arc<CrlSet>>,
}
807
808impl ReloadHandle {
809    /// Atomically replace the API key list used by the auth middleware.
810    pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
811        if let Some(ref auth) = self.auth {
812            auth.reload_keys(keys);
813        }
814    }
815
816    /// Atomically replace the RBAC policy used by the RBAC middleware.
817    pub fn reload_rbac(&self, policy: RbacPolicy) {
818        if let Some(ref rbac) = self.rbac {
819            rbac.store(Arc::new(policy));
820            tracing::info!("RBAC policy reloaded");
821        }
822    }
823
824    /// Force an immediate refresh of all cached mTLS CRLs.
825    ///
826    /// # Errors
827    ///
828    /// Returns an error if CRL refresh is unavailable or verifier rebuild fails.
829    pub async fn refresh_crls(&self) -> Result<(), McpxError> {
830        let Some(ref crl_set) = self.crl_set else {
831            return Err(McpxError::Config(
832                "CRL refresh requested but mTLS CRL support is not configured".into(),
833            ));
834        };
835
836        crl_set.force_refresh().await
837    }
838}
839
840/// Generic MCP HTTP server.
841///
842/// Wraps an axum server with `/healthz` and `/mcp` endpoints.
843/// When `tls_cert_path` and `tls_key_path` are both set, the server binds
844/// with TLS (rustls). Optionally supports mTLS client certificate auth.
845///
846/// # Errors
847///
848/// Returns an error if the TCP listener cannot bind, TLS config is invalid,
849/// or the server fails.
850// TODO(refactor): cognitive complexity reduced from 111/25 to 83/25 by
851// extracting `run_server` (serve-loop tail) and `install_oauth_proxy_routes`.
852// Remaining flow is a linear router builder: middleware layering, feature-
853// gated auth/RBAC wiring, and PRM/metrics installation. Further extraction
854// would require threading many `&mut Router` helpers and hurt readability
855// of the layer order (which is security-relevant and must stay visible).
856#[allow(
857    clippy::too_many_lines,
858    clippy::cognitive_complexity,
859    reason = "middleware layer order is security-critical and must remain visible at one glance; extracting `&mut Router` helpers would obscure the auth/RBAC/origin/rate-limit ordering"
860)]
861/// Internal bundle of values produced by [`build_app_router`] and
862/// consumed by [`serve`] / [`serve_with_listener`] when driving the
863/// HTTP listener.
864struct AppRunParams {
865    /// TLS cert/key paths when TLS is configured.
866    tls_paths: Option<(PathBuf, PathBuf)>,
867    /// mTLS configuration when mutual-TLS auth is enabled.
868    mtls_config: Option<MtlsConfig>,
869    /// Graceful shutdown drain window.
870    shutdown_timeout: Duration,
871    /// Shared auth state used by hot-reload callbacks.
872    auth_state: Option<Arc<AuthState>>,
873    /// Hot-reloadable RBAC state used by reload callbacks.
874    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
875    /// Optional callback that receives the final [`ReloadHandle`].
876    on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
877    /// Server-internal cancellation token. Cancelled by [`run_server`]
878    /// once the shutdown trigger fires (so rmcp's child token also
879    /// fires, terminating in-flight MCP sessions).
880    ct: CancellationToken,
881    /// `"http"` or `"https"` -- used only for boot-time logging.
882    scheme: &'static str,
883    /// Server name -- used only for boot-time logging.
884    name: String,
885}
886
887/// Build the full application axum [`axum::Router`] (MCP route +
888/// middleware stack + admin + OAuth + health endpoints + security
889/// headers + CORS + compression + concurrency limit + origin check)
890/// and the [`AppRunParams`] needed to drive it.
891///
892/// This is the shared core of [`serve`] and [`serve_with_listener`].
893/// It performs *no* network I/O: callers are responsible for binding
894/// (or accepting a pre-bound) [`TcpListener`] and invoking
895/// [`run_server`].
896#[allow(
897    clippy::cognitive_complexity,
898    reason = "router assembly is intrinsically sequential; splitting harms readability"
899)]
900#[allow(
901    deprecated,
902    reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
903)]
904fn build_app_router<H, F>(
905    mut config: McpServerConfig,
906    handler_factory: F,
907) -> anyhow::Result<(axum::Router, AppRunParams)>
908where
909    H: ServerHandler + 'static,
910    F: Fn() -> H + Send + Sync + Clone + 'static,
911{
912    let ct = CancellationToken::new();
913
914    let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
915    tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");
916
917    let mcp_service = StreamableHttpService::new(
918        move || Ok(handler_factory()),
919        {
920            let mut mgr = LocalSessionManager::default();
921            mgr.session_config.keep_alive = Some(config.session_idle_timeout);
922            mgr.into()
923        },
924        StreamableHttpServerConfig::default()
925            .with_allowed_hosts(allowed_hosts)
926            .with_sse_keep_alive(Some(config.sse_keep_alive))
927            .with_cancellation_token(ct.child_token()),
928    );
929
930    // Build the MCP route, optionally wrapped with auth and RBAC middleware.
931    let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);
932
933    // Build auth state eagerly when auth is configured so we can wire both
934    // the auth middleware *and* the optional admin router against the same
935    // state. The middleware itself is installed further down in layer order.
936    let auth_state: Option<Arc<AuthState>> = match config.auth {
937        Some(ref auth_config) if auth_config.enabled => {
938            let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
939            let pre_auth_limiter = auth_config
940                .rate_limit
941                .as_ref()
942                .map(crate::auth::build_pre_auth_limiter);
943
944            #[cfg(feature = "oauth")]
945            let jwks_cache = auth_config
946                .oauth
947                .as_ref()
948                .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
949                .transpose()
950                .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;
951
952            Some(Arc::new(AuthState {
953                api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
954                rate_limiter,
955                pre_auth_limiter,
956                #[cfg(feature = "oauth")]
957                jwks_cache,
958                seen_identities: std::sync::Mutex::new(std::collections::HashSet::new()),
959                counters: crate::auth::AuthCounters::default(),
960            }))
961        }
962        _ => None,
963    };
964
965    // Build the RBAC policy swap early so the admin router and the later
966    // RBAC middleware layer share the same hot-reloadable state.
967    let rbac_swap = Arc::new(ArcSwap::new(
968        config
969            .rbac
970            .clone()
971            .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
972    ));
973
974    // Optional /admin/* diagnostic routes. Merged BEFORE the
975    // body-limit/timeout/RBAC/origin/auth layers so all of them apply.
976    if config.admin_enabled {
977        let Some(ref auth_state_ref) = auth_state else {
978            return Err(anyhow::anyhow!(
979                "admin_enabled=true requires auth to be configured and enabled"
980            ));
981        };
982        let admin_state = crate::admin::AdminState {
983            started_at: std::time::Instant::now(),
984            name: config.name.clone(),
985            version: config.version.clone(),
986            auth: Some(Arc::clone(auth_state_ref)),
987            rbac: Arc::clone(&rbac_swap),
988        };
989        let admin_cfg = crate::admin::AdminConfig {
990            role: config.admin_role.clone(),
991        };
992        mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
993        tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
994    }
995
996    // ----- Middleware order (CRITICAL: read carefully) ------------------
997    //
998    // axum/tower applies layers **bottom-up** at runtime: the LAST layer
999    // added is the OUTERMOST (runs first on a request). To achieve a
1000    // request-time flow of:
1001    //
1002    //   body-limit -> timeout -> auth -> rbac -> handler
1003    //
1004    // we add layers in the REVERSE order:
1005    //
1006    //   1. RBAC               (innermost, runs last before handler)
1007    //   2. auth               (parses identity, sets extension for RBAC)
1008    //   3. timeout            (bounds total request time)
1009    //   4. body-limit         (outermost on /mcp; caps payload before
1010    //                          anything else reads/buffers it)
1011    //
1012    // Origin validation is installed on the OUTER router (after the
1013    // /mcp router is merged in), so it also protects /healthz, /readyz,
1014    // /version, and any OAuth proxy endpoints.
1015    //
1016    // Rationale:
1017    // - Body-limit must be outermost on /mcp so RBAC (which reads the
1018    //   JSON-RPC body) cannot be DoS'd by a 100MB payload.
1019    // - Auth must run before RBAC because RBAC consumes
1020    //   `req.extensions().get::<AuthIdentity>()` to enforce per-role
1021    //   policy.
1022    // - Origin runs before auth so we reject cross-origin requests
1023    //   without spending Argon2 cycles on unauthenticated callers.
1024
1025    // [1] RBAC + tool rate-limit layer (innermost; closest to handler).
1026    // Always installed: even when RBAC is disabled, tool rate limiting may
1027    // be active (MCP spec: servers MUST rate limit tool invocations).
1028    {
1029        let tool_limiter: Option<Arc<ToolRateLimiter>> =
1030            config.tool_rate_limit.map(build_tool_rate_limiter);
1031
1032        if rbac_swap.load().is_enabled() {
1033            tracing::info!("RBAC enforcement enabled on /mcp");
1034        }
1035        if let Some(limit) = config.tool_rate_limit {
1036            tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
1037        }
1038
1039        let rbac_for_mw = Arc::clone(&rbac_swap);
1040        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1041            let p = rbac_for_mw.load_full();
1042            let tl = tool_limiter.clone();
1043            rbac_middleware(p, tl, req, next)
1044        }));
1045    }
1046
1047    // [2] Auth layer (runs before RBAC so AuthIdentity is in extensions).
1048    if let Some(ref auth_config) = config.auth
1049        && auth_config.enabled
1050    {
1051        let Some(ref state) = auth_state else {
1052            return Err(anyhow::anyhow!("auth state missing despite enabled config"));
1053        };
1054
1055        let methods: Vec<&str> = [
1056            auth_config.mtls.is_some().then_some("mTLS"),
1057            (!auth_config.api_keys.is_empty()).then_some("bearer"),
1058            #[cfg(feature = "oauth")]
1059            auth_config.oauth.is_some().then_some("oauth-jwt"),
1060        ]
1061        .into_iter()
1062        .flatten()
1063        .collect();
1064
1065        tracing::info!(
1066            methods = %methods.join(", "),
1067            api_keys = auth_config.api_keys.len(),
1068            "auth enabled on /mcp"
1069        );
1070
1071        let state_for_mw = Arc::clone(state);
1072        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1073            let s = Arc::clone(&state_for_mw);
1074            auth_middleware(s, req, next)
1075        }));
1076    }
1077
1078    // [3] Request timeout (returns 408 on expiry). Bounds total request
1079    // duration including auth + handler.
1080    mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
1081        axum::http::StatusCode::REQUEST_TIMEOUT,
1082        config.request_timeout,
1083    ));
1084
1085    // [4] Request body size limit (OUTERMOST on /mcp). Prevents OOM /
1086    // DoS from oversized payloads BEFORE any inner layer (auth, RBAC)
1087    // attempts to buffer or parse the body.
1088    mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
1089        config.max_request_body,
1090    ));
1091
1092    // Compute the effective allowed-origins list for the outer
1093    // origin-check layer (installed on the merged router below). When
1094    // `allowed_origins` is empty but `public_url` is set, auto-derive
1095    // the origin from the public URL so MCP clients (e.g. Claude Code)
1096    // that send `Origin: <server-url>` are accepted without explicit
1097    // config.
1098    let mut effective_origins = config.allowed_origins.clone();
1099    if effective_origins.is_empty()
1100        && let Some(ref url) = config.public_url
1101    {
1102        // Origin = scheme + "://" + host (+ ":" + port if non-default).
1103        // Strip any path/query from the public URL.
1104        if let Some(scheme_end) = url.find("://") {
1105            let after_scheme = &url[scheme_end + 3..];
1106            let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
1107            let origin = format!("{}{}", &url[..scheme_end + 3], &after_scheme[..host_end]);
1108            tracing::info!(
1109                %origin,
1110                "auto-derived allowed origin from public_url"
1111            );
1112            effective_origins.push(origin);
1113        }
1114    }
1115    let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
1116    let cors_origins = Arc::clone(&allowed_origins);
1117    let log_request_headers = config.log_request_headers;
1118
1119    let readyz_route = if let Some(check) = config.readiness_check.take() {
1120        axum::routing::get(move || readyz(Arc::clone(&check)))
1121    } else {
1122        axum::routing::get(healthz)
1123    };
1124
1125    #[allow(unused_mut)] // mut needed when oauth feature adds PRM route
1126    let mut router = axum::Router::new()
1127        .route("/healthz", axum::routing::get(healthz))
1128        .route("/readyz", readyz_route)
1129        .route(
1130            "/version",
1131            axum::routing::get({
1132                // Pre-serialize the version payload once at router-build
1133                // time. The handler then serves a cheap `Arc::clone` of the
1134                // immutable bytes per request, avoiding `serde_json::Value`
1135                // allocation + serialization on every `/version` hit.
1136                let payload_bytes: Arc<[u8]> =
1137                    serialize_version_payload(&config.name, &config.version);
1138                move || {
1139                    let p = Arc::clone(&payload_bytes);
1140                    async move {
1141                        (
1142                            [(axum::http::header::CONTENT_TYPE, "application/json")],
1143                            p.to_vec(),
1144                        )
1145                    }
1146                }
1147            }),
1148        )
1149        .merge(mcp_router);
1150
1151    // Merge application-specific routes (bypass MCP auth/RBAC middleware).
1152    if let Some(extra) = config.extra_router.take() {
1153        router = router.merge(extra);
1154    }
1155
1156    // RFC 9728: Protected Resource Metadata endpoint.
1157    // When OAuth is configured, serve full metadata with authorization_servers.
1158    // Otherwise, serve a minimal document with just the resource URL and no
1159    // authorization_servers -- this tells MCP clients (e.g. Claude Code SDK)
1160    // that the server exists but does NOT require OAuth authentication,
1161    // preventing them from gating the connection behind a broken auth flow.
1162    let server_url = if let Some(ref url) = config.public_url {
1163        url.trim_end_matches('/').to_owned()
1164    } else {
1165        let prm_scheme = if config.tls_cert_path.is_some() {
1166            "https"
1167        } else {
1168            "http"
1169        };
1170        format!("{prm_scheme}://{}", config.bind_addr)
1171    };
1172    let resource_url = format!("{server_url}/mcp");
1173
1174    #[cfg(feature = "oauth")]
1175    let prm_metadata = if let Some(ref auth_config) = config.auth
1176        && let Some(ref oauth_config) = auth_config.oauth
1177    {
1178        crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
1179    } else {
1180        serde_json::json!({ "resource": resource_url })
1181    };
1182    #[cfg(not(feature = "oauth"))]
1183    let prm_metadata = serde_json::json!({ "resource": resource_url });
1184
1185    router = router.route(
1186        "/.well-known/oauth-protected-resource",
1187        axum::routing::get(move || {
1188            let m = prm_metadata.clone();
1189            async move { axum::Json(m) }
1190        }),
1191    );
1192
1193    // OAuth 2.1 proxy endpoints: when an OAuth proxy is configured, expose
1194    // /authorize, /token, /register, and authorization server metadata so
1195    // MCP clients can perform Authorization Code + PKCE against the upstream
1196    // IdP (e.g. Keycloak) transparently.
1197    #[cfg(feature = "oauth")]
1198    if let Some(ref auth_config) = config.auth
1199        && let Some(ref oauth_config) = auth_config.oauth
1200        && oauth_config.proxy.is_some()
1201    {
1202        router =
1203            install_oauth_proxy_routes(router, &server_url, oauth_config, auth_state.as_ref())?;
1204    }
1205
1206    // OWASP security response headers (applied to all responses).
1207    // HSTS is conditional on TLS being configured.
1208    let is_tls = config.tls_cert_path.is_some();
1209    let security_headers_cfg = Arc::new(config.security_headers.clone());
1210    router = router.layer(axum::middleware::from_fn(move |req, next| {
1211        let cfg = Arc::clone(&security_headers_cfg);
1212        security_headers_middleware(is_tls, cfg, req, next)
1213    }));
1214
1215    // CORS preflight layer (required for browser-based MCP clients).
1216    // Uses the same effective origins as the origin check middleware
1217    // (including auto-derived origin from public_url).
1218    if !cors_origins.is_empty() {
1219        let cors = tower_http::cors::CorsLayer::new()
1220            .allow_origin(
1221                cors_origins
1222                    .iter()
1223                    .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
1224                    .collect::<Vec<_>>(),
1225            )
1226            .allow_methods([
1227                axum::http::Method::GET,
1228                axum::http::Method::POST,
1229                axum::http::Method::OPTIONS,
1230            ])
1231            .allow_headers([
1232                axum::http::header::CONTENT_TYPE,
1233                axum::http::header::AUTHORIZATION,
1234            ]);
1235        router = router.layer(cors);
1236    }
1237
1238    // Optional response compression (gzip + brotli). Skips small bodies
1239    // to avoid overhead. Applied after CORS so preflight responses remain
1240    // uncompressed.
1241    if config.compression_enabled {
1242        use tower_http::compression::Predicate as _;
1243        let predicate = tower_http::compression::DefaultPredicate::new().and(
1244            tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
1245        );
1246        router = router.layer(
1247            tower_http::compression::CompressionLayer::new()
1248                .gzip(true)
1249                .br(true)
1250                .compress_when(predicate),
1251        );
1252        tracing::info!(
1253            min_size = config.compression_min_size,
1254            "response compression enabled (gzip, br)"
1255        );
1256    }
1257
1258    // Optional global concurrency cap. `load_shed` converts the
1259    // `ConcurrencyLimit` back-pressure error into 503 instead of hanging.
1260    if let Some(max) = config.max_concurrent_requests {
1261        let overload_handler = tower::ServiceBuilder::new()
1262            .layer(axum::error_handling::HandleErrorLayer::new(
1263                |_err: tower::BoxError| async {
1264                    (
1265                        axum::http::StatusCode::SERVICE_UNAVAILABLE,
1266                        axum::Json(serde_json::json!({
1267                            "error": "overloaded",
1268                            "error_description": "server is at capacity, retry later"
1269                        })),
1270                    )
1271                },
1272            ))
1273            .layer(tower::load_shed::LoadShedLayer::new())
1274            .layer(tower::limit::ConcurrencyLimitLayer::new(max));
1275        router = router.layer(overload_handler);
1276        tracing::info!(max, "global concurrency limit enabled");
1277    }
1278
1279    // JSON fallback for unmatched routes. Without this, axum returns
1280    // an empty-body 404 that breaks MCP clients (e.g. Claude Code SDK)
1281    // when they probe OAuth endpoints like /authorize or /token.
1282    router = router.fallback(|| async {
1283        (
1284            axum::http::StatusCode::NOT_FOUND,
1285            axum::Json(serde_json::json!({
1286                "error": "not_found",
1287                "error_description": "The requested endpoint does not exist"
1288            })),
1289        )
1290    });
1291
1292    // Prometheus metrics: recording middleware + separate listener.
1293    #[cfg(feature = "metrics")]
1294    if config.metrics_enabled {
1295        let metrics = Arc::new(
1296            crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
1297        );
1298        let m = Arc::clone(&metrics);
1299        router = router.layer(axum::middleware::from_fn(
1300            move |req: Request<Body>, next: Next| {
1301                let m = Arc::clone(&m);
1302                metrics_middleware(m, req, next)
1303            },
1304        ));
1305        let metrics_bind = config.metrics_bind.clone();
1306        let metrics_shutdown = ct.clone();
1307        tokio::spawn(async move {
1308            if let Err(e) =
1309                crate::metrics::serve_metrics(metrics_bind, metrics, metrics_shutdown).await
1310            {
1311                tracing::error!("metrics listener failed: {e}");
1312            }
1313        });
1314    }
1315
1316    // Origin validation layer (MCP spec: servers MUST validate the
1317    // Origin header to prevent DNS rebinding attacks). Installed as the
1318    // OUTERMOST layer on the OUTER router so it protects ALL routes
1319    // (`/mcp`, `/healthz`, `/readyz`, `/version`, OAuth proxy endpoints,
1320    // admin endpoints, extra_router, etc.) and runs BEFORE auth so we
1321    // reject cross-origin attackers without spending Argon2 cycles.
1322    //
1323    // Origin-less requests (e.g. server-to-server probes, curl, native
1324    // MCP clients) are permitted; only requests with an Origin header
1325    // that does not match `effective_origins` are rejected.
1326    router = router.layer(axum::middleware::from_fn(move |req, next| {
1327        let origins = Arc::clone(&allowed_origins);
1328        origin_check_middleware(origins, log_request_headers, req, next)
1329    }));
1330
1331    let scheme = if config.tls_cert_path.is_some() {
1332        "https"
1333    } else {
1334        "http"
1335    };
1336
1337    let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
1338        (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
1339        _ => None,
1340    };
1341    let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();
1342
1343    Ok((
1344        router,
1345        AppRunParams {
1346            tls_paths,
1347            mtls_config,
1348            shutdown_timeout: config.shutdown_timeout,
1349            auth_state,
1350            rbac_swap,
1351            on_reload_ready: config.on_reload_ready.take(),
1352            ct,
1353            scheme,
1354            name: config.name.clone(),
1355        },
1356    ))
1357}
1358
1359/// Run the MCP HTTP server, binding to `config.bind_addr` and serving
1360/// until an OS shutdown signal (Ctrl-C / SIGTERM) is received.
1361///
1362/// This is the standard entry point for production deployments. For
1363/// deterministic shutdown control (e.g. integration tests), see
1364/// [`serve_with_listener`].
1365///
1366/// The configuration must be validated first via
1367/// [`McpServerConfig::validate`], which returns a [`Validated`] proof
1368/// token. This typestate guarantees, at compile time, that the server
1369/// never starts with an invalid configuration.
1370///
1371/// # Errors
1372///
1373/// Returns [`McpxError::Startup`] if binding to `config.bind_addr`
1374/// fails, or if the underlying axum server returns an error.
1375pub async fn serve<H, F>(
1376    config: Validated<McpServerConfig>,
1377    handler_factory: F,
1378) -> Result<(), McpxError>
1379where
1380    H: ServerHandler + 'static,
1381    F: Fn() -> H + Send + Sync + Clone + 'static,
1382{
1383    let config = config.into_inner();
1384    #[allow(
1385        deprecated,
1386        reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1387    )]
1388    let bind_addr = config.bind_addr.clone();
1389    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1390
1391    let listener = TcpListener::bind(&bind_addr)
1392        .await
1393        .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1394    log_listening(&params.name, params.scheme, &bind_addr);
1395
1396    run_server(
1397        router,
1398        listener,
1399        params.tls_paths,
1400        params.mtls_config,
1401        params.shutdown_timeout,
1402        params.auth_state,
1403        params.rbac_swap,
1404        params.on_reload_ready,
1405        params.ct,
1406    )
1407    .await
1408    .map_err(anyhow_to_startup)
1409}
1410
1411/// Run the MCP HTTP server on a pre-bound [`TcpListener`], with optional
1412/// readiness signalling and external shutdown control.
1413///
1414/// This variant is intended for **deterministic integration tests** and
1415/// for embedders that need to bind the listening socket themselves
1416/// (e.g. systemd socket activation). Compared to [`serve`]:
1417///
1418/// * The caller passes a `TcpListener` that is already bound. This
1419///   eliminates the bind race in tests that previously required
1420///   poll-the-`/healthz`-loop start-up detection.
1421/// * `ready_tx`, when `Some`, receives the socket's
1422///   [`SocketAddr`] *after* the router is built and immediately before
1423///   the server starts accepting connections. Tests can `await` the
1424///   matching `oneshot::Receiver` to know exactly when it is safe to
1425///   issue requests.
1426/// * `shutdown`, when `Some`, gives the caller a
1427///   [`CancellationToken`] that triggers the same graceful-shutdown
1428///   path as a real OS signal. This avoids cross-platform issues with
1429///   sending real `SIGTERM` from tests on Windows.
1430///
1431/// All three optional parameters degrade gracefully: if `ready_tx` is
1432/// `None`, no signal is sent; if `shutdown` is `None`, the server only
1433/// stops on an OS signal (just like [`serve`]).
1434///
1435/// # Errors
1436///
1437/// Returns [`McpxError::Startup`] if router construction fails, if reading
1438/// the listener's `local_addr()` fails, or if the underlying axum
1439/// server returns an error.
1440pub async fn serve_with_listener<H, F>(
1441    listener: TcpListener,
1442    config: Validated<McpServerConfig>,
1443    handler_factory: F,
1444    ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
1445    shutdown: Option<CancellationToken>,
1446) -> Result<(), McpxError>
1447where
1448    H: ServerHandler + 'static,
1449    F: Fn() -> H + Send + Sync + Clone + 'static,
1450{
1451    let config = config.into_inner();
1452    let local_addr = listener
1453        .local_addr()
1454        .map_err(|e| io_to_startup("listener.local_addr", e))?;
1455    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1456
1457    log_listening(&params.name, params.scheme, &local_addr.to_string());
1458
1459    // Forward external shutdown into the server-internal cancellation
1460    // token so `run_server`'s shutdown trigger picks it up alongside
1461    // any real OS signal.
1462    if let Some(external) = shutdown {
1463        let internal = params.ct.clone();
1464        tokio::spawn(async move {
1465            external.cancelled().await;
1466            internal.cancel();
1467        });
1468    }
1469
1470    // Signal readiness *after* the router is fully built and external
1471    // shutdown is wired, but *before* run_server takes ownership of
1472    // the listener. The receiver can immediately issue requests.
1473    if let Some(tx) = ready_tx {
1474        // Receiver may have been dropped (test gave up). That's fine.
1475        let _ = tx.send(local_addr);
1476    }
1477
1478    run_server(
1479        router,
1480        listener,
1481        params.tls_paths,
1482        params.mtls_config,
1483        params.shutdown_timeout,
1484        params.auth_state,
1485        params.rbac_swap,
1486        params.on_reload_ready,
1487        params.ct,
1488    )
1489    .await
1490    .map_err(anyhow_to_startup)
1491}
1492
1493/// Emit the standard "listening on …" log lines used by both
1494/// [`serve`] and [`serve_with_listener`].
1495#[allow(
1496    clippy::cognitive_complexity,
1497    reason = "tracing::info! macro expansions inflate the score; logic is trivial"
1498)]
1499fn log_listening(name: &str, scheme: &str, addr: &str) {
1500    tracing::info!("{name} listening on {addr}");
1501    tracing::info!("  MCP endpoint: {scheme}://{addr}/mcp");
1502    tracing::info!("  Health check: {scheme}://{addr}/healthz");
1503    tracing::info!("  Readiness:   {scheme}://{addr}/readyz");
1504}
1505
1506/// Drive the chosen axum server variant (TLS or plain) with a graceful
1507/// shutdown window. Consumes the router and listener.
1508///
1509/// # Shutdown semantics
1510///
1511/// A single shutdown trigger (the FIRST of: OS signal via
1512/// `shutdown_signal()`, or external cancellation of `ct`) starts BOTH:
1513///
1514/// 1. axum's `.with_graceful_shutdown(...)` future, which stops
1515///    accepting new connections and waits for in-flight requests to
1516///    drain;
1517/// 2. a `tokio::time::sleep(shutdown_timeout)` race that forces exit if
1518///    drainage exceeds `shutdown_timeout`.
1519///
1520/// Previously this function awaited `shutdown_signal()` independently
1521/// in BOTH branches of a `tokio::select!`. Because `shutdown_signal`
1522/// resolves once per future and consumes one signal, the force-exit
1523/// timer was tied to a SECOND signal (a second SIGTERM the operator
1524/// would never send). Under a single SIGTERM the graceful drain could
1525/// hang indefinitely. The current implementation derives both branches
1526/// from a single shared trigger so the timeout race is anchored to the
1527/// FIRST (and only) signal.
1528#[allow(
1529    clippy::too_many_arguments,
1530    clippy::cognitive_complexity,
1531    reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
1532)]
1533async fn run_server(
1534    router: axum::Router,
1535    listener: TcpListener,
1536    tls_paths: Option<(PathBuf, PathBuf)>,
1537    mtls_config: Option<MtlsConfig>,
1538    shutdown_timeout: Duration,
1539    auth_state: Option<Arc<AuthState>>,
1540    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
1541    mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
1542    ct: CancellationToken,
1543) -> anyhow::Result<()> {
1544    // `shutdown_trigger` fires when the FIRST source resolves: either
1545    // an OS signal (Ctrl-C / SIGTERM) or external cancellation of `ct`
1546    // (which the test harness uses for deterministic shutdown).
1547    let shutdown_trigger = CancellationToken::new();
1548    {
1549        let trigger = shutdown_trigger.clone();
1550        let parent = ct.clone();
1551        tokio::spawn(async move {
1552            tokio::select! {
1553                () = shutdown_signal() => {}
1554                () = parent.cancelled() => {}
1555            }
1556            trigger.cancel();
1557        });
1558    }
1559
1560    let graceful = {
1561        let trigger = shutdown_trigger.clone();
1562        let ct = ct.clone();
1563        async move {
1564            trigger.cancelled().await;
1565            tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
1566            ct.cancel();
1567        }
1568    };
1569
1570    let force_exit_timer = {
1571        let trigger = shutdown_trigger.clone();
1572        async move {
1573            trigger.cancelled().await;
1574            tokio::time::sleep(shutdown_timeout).await;
1575        }
1576    };
1577
1578    if let Some((cert_path, key_path)) = tls_paths {
1579        let crl_set = if let Some(mtls) = mtls_config.as_ref()
1580            && mtls.crl_enabled
1581        {
1582            let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
1583            let (crl_set, discover_rx) =
1584                mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
1585                    .await
1586                    .map_err(|error| anyhow::anyhow!(error.to_string()))?;
1587            tokio::spawn(mtls_revocation::run_crl_refresher(
1588                Arc::clone(&crl_set),
1589                discover_rx,
1590                ct.clone(),
1591            ));
1592            Some(crl_set)
1593        } else {
1594            None
1595        };
1596
1597        if let Some(cb) = on_reload_ready.take() {
1598            cb(ReloadHandle {
1599                auth: auth_state.clone(),
1600                rbac: Some(Arc::clone(&rbac_swap)),
1601                crl_set: crl_set.clone(),
1602            });
1603        }
1604
1605        let tls_listener = TlsListener::new(
1606            listener,
1607            &cert_path,
1608            &key_path,
1609            mtls_config.as_ref(),
1610            crl_set,
1611        )?;
1612        let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
1613        tokio::select! {
1614            result = axum::serve(tls_listener, make_svc)
1615                .with_graceful_shutdown(graceful) => { result?; }
1616            () = force_exit_timer => {
1617                tracing::warn!("shutdown timeout exceeded, forcing exit");
1618            }
1619        }
1620    } else {
1621        if let Some(cb) = on_reload_ready.take() {
1622            cb(ReloadHandle {
1623                auth: auth_state,
1624                rbac: Some(rbac_swap),
1625                crl_set: None,
1626            });
1627        }
1628
1629        let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
1630        tokio::select! {
1631            result = axum::serve(listener, make_svc)
1632                .with_graceful_shutdown(graceful) => { result?; }
1633            () = force_exit_timer => {
1634                tracing::warn!("shutdown timeout exceeded, forcing exit");
1635            }
1636        }
1637    }
1638
1639    Ok(())
1640}
1641
/// Mount the OAuth 2.1 proxy endpoints on `router`:
/// `/.well-known/oauth-authorization-server`, `/authorize`, `/token`,
/// `/register`, and, when enabled, the admin sub-router produced by
/// [`build_oauth_admin_router`].
///
/// If `oauth_config.proxy` is `None`, `router` is returned unchanged.
///
/// `/token` and `/register` are wrapped in
/// [`oauth_token_cache_headers_middleware`]. The admin routes are mounted
/// only when `expose_admin_endpoints` is set and at least one of
/// `introspection_url` / `revocation_url` is configured.
///
/// # Errors
///
/// Returns an error if the shared [`crate::oauth::OauthHttpClient`]
/// cannot be initialized, or if [`build_oauth_admin_router`] fails.
#[cfg(feature = "oauth")]
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let Some(proxy) = oauth_config.proxy.as_ref() else {
        return Ok(router);
    };

    // One HTTP client for every proxy endpoint; it is cloned into each
    // handler that talks to the upstream server.
    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;

    // Per-handler captured state, prepared up front so the route chain
    // below stays readable.
    let metadata = crate::oauth::authorization_server_metadata(server_url, oauth_config);
    let authorize_cfg = proxy.clone();
    let token_cfg = proxy.clone();
    let token_client = http.clone();
    let register_cfg = proxy.clone();

    let router = router
        .route(
            "/.well-known/oauth-authorization-server",
            axum::routing::get(move || {
                let body = metadata.clone();
                async move { axum::Json(body) }
            }),
        )
        .route(
            "/authorize",
            axum::routing::get(
                move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
                    let cfg = authorize_cfg.clone();
                    async move { crate::oauth::handle_authorize(&cfg, &query.unwrap_or_default()) }
                },
            ),
        )
        .route(
            "/token",
            axum::routing::post(move |body: String| {
                let cfg = token_cfg.clone();
                let client = token_client.clone();
                async move { crate::oauth::handle_token(&client, &cfg, &body).await }
            })
            .layer(axum::middleware::from_fn(
                oauth_token_cache_headers_middleware,
            )),
        )
        .route(
            "/register",
            axum::routing::post(move |axum::Json(body): axum::Json<serde_json::Value>| {
                // Moves the captured config out of the closure rather
                // than cloning it again.
                let cfg = register_cfg;
                async move { axum::Json(crate::oauth::handle_register(&cfg, &body)) }
            })
            .layer(axum::middleware::from_fn(
                oauth_token_cache_headers_middleware,
            )),
        );

    let introspect_enabled = proxy.expose_admin_endpoints && proxy.introspection_url.is_some();
    let revoke_enabled = proxy.expose_admin_endpoints && proxy.revocation_url.is_some();
    let admin_routes_enabled = introspect_enabled || revoke_enabled;

    if proxy.expose_admin_endpoints
        && !proxy.require_auth_on_admin_endpoints
        && proxy.allow_unauthenticated_admin_endpoints
    {
        // The operator explicitly opted out of authentication on the
        // admin endpoints. Log it loudly at startup so the choice is
        // auditable.
        tracing::warn!(
            "OAuth introspect/revoke endpoints are unauthenticated by explicit \
             allow_unauthenticated_admin_endpoints opt-out; ensure an \
             authenticated reverse proxy fronts these routes"
        );
    }

    let admin_router = if admin_routes_enabled {
        build_oauth_admin_router(proxy, http, auth_state)?
    } else {
        axum::Router::new()
    };
    let router = router.merge(admin_router);

    tracing::info!(
        introspect = introspect_enabled,
        revoke = revoke_enabled,
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
1742
/// Assemble the optional admin sub-router: `POST /introspect` when
/// `proxy.introspection_url` is set and `POST /revoke` when
/// `proxy.revocation_url` is set.
///
/// Every route is wrapped in [`oauth_token_cache_headers_middleware`].
/// When `proxy.require_auth_on_admin_endpoints` is set, the auth
/// middleware is layered on top as the outermost layer.
///
/// # Errors
///
/// Returns [`McpxError::Startup`] when authentication is required on the
/// admin endpoints but `auth_state` is `None`.
#[cfg(feature = "oauth")]
fn build_oauth_admin_router(
    proxy: &crate::oauth::OAuthProxyConfig,
    http: crate::oauth::OauthHttpClient,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let mut routes = axum::Router::new();

    if proxy.introspection_url.is_some() {
        let cfg = proxy.clone();
        let client = http.clone();
        routes = routes.route(
            "/introspect",
            axum::routing::post(move |body: String| {
                let cfg = cfg.clone();
                let client = client.clone();
                async move { crate::oauth::handle_introspect(&client, &cfg, &body).await }
            }),
        );
    }

    if proxy.revocation_url.is_some() {
        let cfg = proxy.clone();
        // Last use of `http`, so it is moved rather than cloned.
        let client = http;
        routes = routes.route(
            "/revoke",
            axum::routing::post(move |body: String| {
                let cfg = cfg.clone();
                let client = client.clone();
                async move { crate::oauth::handle_revoke(&client, &cfg, &body).await }
            }),
        );
    }

    let routes = routes.layer(axum::middleware::from_fn(
        oauth_token_cache_headers_middleware,
    ));

    if !proxy.require_auth_on_admin_endpoints {
        return Ok(routes);
    }

    let state = auth_state.map(Arc::clone).ok_or_else(|| {
        McpxError::Startup("oauth proxy admin endpoints require auth state".into())
    })?;
    Ok(routes.layer(axum::middleware::from_fn(move |req, next| {
        auth_middleware(Arc::clone(&state), req, next)
    })))
}
1801
1802/// Build the host allow-list for rmcp's DNS rebinding protection.
1803///
1804/// Includes loopback hosts by default, then augments with host/authority
1805/// derived from `public_url` and the server bind address.
1806fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
1807    let mut hosts = vec![
1808        "localhost".to_owned(),
1809        "127.0.0.1".to_owned(),
1810        "::1".to_owned(),
1811    ];
1812
1813    if let Some(url) = public_url
1814        && let Ok(uri) = url.parse::<axum::http::Uri>()
1815        && let Some(authority) = uri.authority()
1816    {
1817        let host = authority.host().to_owned();
1818        if !hosts.iter().any(|h| h == &host) {
1819            hosts.push(host);
1820        }
1821
1822        let authority = authority.as_str().to_owned();
1823        if !hosts.iter().any(|h| h == &authority) {
1824            hosts.push(authority);
1825        }
1826    }
1827
1828    if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
1829        && let Some(authority) = uri.authority()
1830    {
1831        let host = authority.host().to_owned();
1832        if !hosts.iter().any(|h| h == &host) {
1833            hosts.push(host);
1834        }
1835
1836        let authority = authority.as_str().to_owned();
1837        if !hosts.iter().any(|h| h == &authority) {
1838            hosts.push(authority);
1839        }
1840    }
1841
1842    hosts
1843}
1844
1845// - TLS support -
1846
1847/// Implement axum's `Connected` trait for `TlsConnInfo` so that
1848/// `ConnectInfo<TlsConnInfo>` is available in middleware when serving
1849/// over our custom `TlsListener`.
1850///
1851/// The identity is read directly from the wrapping
1852/// [`AuthenticatedTlsStream`], which guarantees one-to-one correspondence
1853/// between the TLS connection and its mTLS identity. This eliminates the
1854/// previous shared-map approach which was vulnerable to ephemeral-port
1855/// reuse races (an unauthenticated reconnection from the same `(IP, port)`
1856/// pair could alias a stale entry).
1857impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
1858    for TlsConnInfo
1859{
1860    fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
1861        let addr = *target.remote_addr();
1862        let identity = target.io().identity().cloned();
1863        TlsConnInfo::new(addr, identity)
1864    }
1865}
1866
1867/// A TLS-wrapping listener that implements axum's `Listener` trait.
1868///
1869/// When mTLS is configured, verifies client certificates against the
1870/// configured CA and extracts the client identity at handshake time.
1871/// The extracted identity is bound to the connection itself via the
1872/// returned [`AuthenticatedTlsStream`], so it is impossible for an
1873/// unrelated connection to observe it.
1874struct TlsListener {
1875    inner: TcpListener,
1876    acceptor: tokio_rustls::TlsAcceptor,
1877    mtls_default_role: String,
1878}
1879
1880impl TlsListener {
1881    fn new(
1882        inner: TcpListener,
1883        cert_path: &Path,
1884        key_path: &Path,
1885        mtls_config: Option<&MtlsConfig>,
1886        crl_set: Option<Arc<CrlSet>>,
1887    ) -> anyhow::Result<Self> {
1888        // Install the ring crypto provider (ok to call multiple times).
1889        rustls::crypto::ring::default_provider()
1890            .install_default()
1891            .ok();
1892
1893        let certs = load_certs(cert_path)?;
1894        let key = load_key(key_path)?;
1895
1896        let mtls_default_role;
1897
1898        let tls_config = if let Some(mtls) = mtls_config {
1899            mtls_default_role = mtls.default_role.clone();
1900            let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
1901            {
1902                let Some(crl_set) = crl_set else {
1903                    return Err(anyhow::anyhow!(
1904                        "mTLS CRL verifier requested but CRL state was not initialized"
1905                    ));
1906                };
1907                Arc::new(DynamicClientCertVerifier::new(crl_set))
1908            } else {
1909                let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
1910                if mtls.required {
1911                    rustls::server::WebPkiClientVerifier::builder(root_store)
1912                        .build()
1913                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1914                } else {
1915                    rustls::server::WebPkiClientVerifier::builder(root_store)
1916                        .allow_unauthenticated()
1917                        .build()
1918                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1919                }
1920            };
1921
1922            tracing::info!(
1923                ca = %mtls.ca_cert_path.display(),
1924                required = mtls.required,
1925                crl_enabled = mtls.crl_enabled,
1926                "mTLS client auth configured"
1927            );
1928
1929            rustls::ServerConfig::builder_with_protocol_versions(&[
1930                &rustls::version::TLS12,
1931                &rustls::version::TLS13,
1932            ])
1933            .with_client_cert_verifier(verifier)
1934            .with_single_cert(certs, key)?
1935        } else {
1936            mtls_default_role = "viewer".to_owned();
1937            rustls::ServerConfig::builder_with_protocol_versions(&[
1938                &rustls::version::TLS12,
1939                &rustls::version::TLS13,
1940            ])
1941            .with_no_client_auth()
1942            .with_single_cert(certs, key)?
1943        };
1944
1945        let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
1946        tracing::info!(
1947            "TLS enabled (cert: {}, key: {})",
1948            cert_path.display(),
1949            key_path.display()
1950        );
1951        Ok(Self {
1952            inner,
1953            acceptor,
1954            mtls_default_role,
1955        })
1956    }
1957
1958    /// Extract the mTLS client cert identity from a completed TLS handshake.
1959    /// Returns `None` if no client certificate was presented or if the
1960    /// certificate could not be parsed into an [`AuthIdentity`].
1961    fn extract_handshake_identity(
1962        tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
1963        default_role: &str,
1964        addr: SocketAddr,
1965    ) -> Option<AuthIdentity> {
1966        let (_, server_conn) = tls_stream.get_ref();
1967        let cert_der = server_conn.peer_certificates()?.first()?;
1968        let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
1969        tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
1970        Some(id)
1971    }
1972}
1973
1974/// A TLS stream paired with the mTLS identity extracted at handshake time.
1975///
1976/// Wraps [`tokio_rustls::server::TlsStream`] so the verified client
1977/// identity travels with the connection itself. This replaces the previous
1978/// shared `MtlsIdentities` map, eliminating the
1979/// `(SocketAddr) -> AuthIdentity` aliasing risk caused by ephemeral-port
1980/// reuse and removing the need for an LRU eviction policy.
1981///
1982/// The wrapper is `Unpin` (its inner stream is `Unpin` because
1983/// [`tokio::net::TcpStream`] is `Unpin`), so `AsyncRead`/`AsyncWrite`
1984/// delegation uses safe pin projection via `Pin::new(&mut self.inner)`.
1985pub(crate) struct AuthenticatedTlsStream {
1986    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
1987    identity: Option<AuthIdentity>,
1988}
1989
1990impl AuthenticatedTlsStream {
1991    /// Returns the verified mTLS client identity, if any.
1992    #[must_use]
1993    pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
1994        self.identity.as_ref()
1995    }
1996}
1997
1998impl std::fmt::Debug for AuthenticatedTlsStream {
1999    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2000        f.debug_struct("AuthenticatedTlsStream")
2001            .field("identity", &self.identity.as_ref().map(|id| &id.name))
2002            .finish_non_exhaustive()
2003    }
2004}
2005
2006impl tokio::io::AsyncRead for AuthenticatedTlsStream {
2007    fn poll_read(
2008        mut self: Pin<&mut Self>,
2009        cx: &mut std::task::Context<'_>,
2010        buf: &mut tokio::io::ReadBuf<'_>,
2011    ) -> std::task::Poll<std::io::Result<()>> {
2012        Pin::new(&mut self.inner).poll_read(cx, buf)
2013    }
2014}
2015
2016impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
2017    fn poll_write(
2018        mut self: Pin<&mut Self>,
2019        cx: &mut std::task::Context<'_>,
2020        buf: &[u8],
2021    ) -> std::task::Poll<std::io::Result<usize>> {
2022        Pin::new(&mut self.inner).poll_write(cx, buf)
2023    }
2024
2025    fn poll_flush(
2026        mut self: Pin<&mut Self>,
2027        cx: &mut std::task::Context<'_>,
2028    ) -> std::task::Poll<std::io::Result<()>> {
2029        Pin::new(&mut self.inner).poll_flush(cx)
2030    }
2031
2032    fn poll_shutdown(
2033        mut self: Pin<&mut Self>,
2034        cx: &mut std::task::Context<'_>,
2035    ) -> std::task::Poll<std::io::Result<()>> {
2036        Pin::new(&mut self.inner).poll_shutdown(cx)
2037    }
2038
2039    fn poll_write_vectored(
2040        mut self: Pin<&mut Self>,
2041        cx: &mut std::task::Context<'_>,
2042        bufs: &[std::io::IoSlice<'_>],
2043    ) -> std::task::Poll<std::io::Result<usize>> {
2044        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
2045    }
2046
2047    fn is_write_vectored(&self) -> bool {
2048        self.inner.is_write_vectored()
2049    }
2050}
2051
2052impl axum::serve::Listener for TlsListener {
2053    type Io = AuthenticatedTlsStream;
2054    type Addr = SocketAddr;
2055
2056    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
2057        loop {
2058            let (stream, addr) = match self.inner.accept().await {
2059                Ok(pair) => pair,
2060                Err(e) => {
2061                    tracing::debug!("TCP accept error: {e}");
2062                    continue;
2063                }
2064            };
2065            let tls_stream = match self.acceptor.accept(stream).await {
2066                Ok(s) => s,
2067                Err(e) => {
2068                    tracing::debug!("TLS handshake failed from {addr}: {e}");
2069                    continue;
2070                }
2071            };
2072            let identity =
2073                Self::extract_handshake_identity(&tls_stream, &self.mtls_default_role, addr);
2074            let wrapped = AuthenticatedTlsStream {
2075                inner: tls_stream,
2076                identity,
2077            };
2078            return (wrapped, addr);
2079        }
2080    }
2081
2082    fn local_addr(&self) -> std::io::Result<Self::Addr> {
2083        self.inner.local_addr()
2084    }
2085}
2086
2087fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
2088    use rustls::pki_types::pem::PemObject;
2089    let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
2090        .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
2091        .collect::<Result<_, _>>()
2092        .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
2093    anyhow::ensure!(
2094        !certs.is_empty(),
2095        "no certificates found in {}",
2096        path.display()
2097    );
2098    Ok(certs)
2099}
2100
2101fn load_client_auth_roots(
2102    path: &Path,
2103) -> anyhow::Result<(
2104    Vec<rustls::pki_types::CertificateDer<'static>>,
2105    Arc<RootCertStore>,
2106)> {
2107    let ca_certs = load_certs(path)?;
2108    let mut root_store = RootCertStore::empty();
2109    for cert in &ca_certs {
2110        root_store
2111            .add(cert.clone())
2112            .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
2113    }
2114
2115    Ok((ca_certs, Arc::new(root_store)))
2116}
2117
2118fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
2119    use rustls::pki_types::pem::PemObject;
2120    rustls::pki_types::PrivateKeyDer::from_pem_file(path)
2121        .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
2122}
2123
2124#[allow(clippy::unused_async)]
2125async fn healthz() -> impl IntoResponse {
2126    axum::Json(serde_json::json!({
2127        "status": "ok",
2128    }))
2129}
2130
2131/// Build the `/version` JSON payload for a given server name and version.
2132///
2133/// Build metadata (`build_git_sha`, `build_timestamp`, `rust_version`) is
2134/// read at compile time from the `MCPX_BUILD_SHA`, `MCPX_BUILD_TIME`, and
2135/// `MCPX_RUSTC_VERSION` env vars. Unset values resolve to `"unknown"`.
2136fn version_payload(name: &str, version: &str) -> serde_json::Value {
2137    serde_json::json!({
2138        "name": name,
2139        "version": version,
2140        "build_git_sha": option_env!("MCPX_BUILD_SHA").unwrap_or("unknown"),
2141        "build_timestamp": option_env!("MCPX_BUILD_TIME").unwrap_or("unknown"),
2142        "rust_version": option_env!("MCPX_RUSTC_VERSION").unwrap_or("unknown"),
2143        "mcpx_version": env!("CARGO_PKG_VERSION"),
2144    })
2145}
2146
2147/// Pre-serialize the `/version` payload to immutable bytes.
2148///
2149/// This is called once at router-build time so per-request handling can
2150/// reuse a cheap `Arc<[u8]>` clone instead of re-serializing a
2151/// [`serde_json::Value`] on every hit.
2152///
2153/// Serialization of a flat `serde_json::Value` of static-string fields
2154/// cannot fail in practice; the fallback to `b"{}"` exists only to
2155/// satisfy the crate-wide `unwrap_used` / `expect_used` lint policy.
2156fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
2157    let value = version_payload(name, version);
2158    serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
2159}
2160
2161async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2162    let status = check().await;
2163    let ready = status
2164        .get("ready")
2165        .and_then(serde_json::Value::as_bool)
2166        .unwrap_or(false);
2167    let code = if ready {
2168        axum::http::StatusCode::OK
2169    } else {
2170        axum::http::StatusCode::SERVICE_UNAVAILABLE
2171    };
2172    (code, axum::Json(status))
2173}
2174
2175/// Wait for SIGINT (ctrl-c) or SIGTERM (container stop).
2176///
2177/// On non-Unix platforms, only SIGINT is handled.
2178async fn shutdown_signal() {
2179    let ctrl_c = tokio::signal::ctrl_c();
2180
2181    #[cfg(unix)]
2182    {
2183        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
2184            Ok(mut term) => {
2185                tokio::select! {
2186                    _ = ctrl_c => {}
2187                    _ = term.recv() => {}
2188                }
2189            }
2190            Err(e) => {
2191                tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
2192                ctrl_c.await.ok();
2193            }
2194        }
2195    }
2196
2197    #[cfg(not(unix))]
2198    {
2199        ctrl_c.await.ok();
2200    }
2201}
2202
2203// -- Origin validation (MCP 2025-11-25 spec, section 2.0.1) --
2204
/// Middleware that records HTTP request metrics for every request.
///
/// After the inner service has produced a response, it increments
/// `http_requests_total` (labels: method, path, status) and observes the
/// elapsed time in seconds on `http_request_duration_seconds` (labels:
/// method, path). The response is passed through untouched.
///
/// NOTE(review): the `path` label is the raw request path, so its
/// cardinality follows whatever paths clients send; confirm this is
/// acceptable for the metrics backend.
#[cfg(feature = "metrics")]
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    // Capture the labels before `req` is handed to the inner service.
    let method_label = req.method().to_string();
    let path_label = req.uri().path().to_owned();
    let started = std::time::Instant::now();

    let response = next.run(req).await;

    let status_label = response.status().as_u16().to_string();
    let elapsed_secs = started.elapsed().as_secs_f64();

    let request_counter = metrics
        .http_requests_total
        .with_label_values(&[&method_label, &path_label, &status_label]);
    request_counter.inc();

    let duration_histogram = metrics
        .http_request_duration_seconds
        .with_label_values(&[&method_label, &path_label]);
    duration_histogram.observe(elapsed_secs);

    response
}
2234
2235/// OWASP security header hardening applied to every response.
2236///
2237/// Sets: `X-Content-Type-Options`, `X-Frame-Options`, `Cache-Control`,
2238/// `Referrer-Policy`, `Cross-Origin-Opener-Policy`, `Cross-Origin-Resource-Policy`,
2239/// `Cross-Origin-Embedder-Policy`, `Permissions-Policy`,
2240/// `X-Permitted-Cross-Domain-Policies`, `Content-Security-Policy`,
2241/// `X-DNS-Prefetch-Control`, and (when TLS is active) `Strict-Transport-Security`.
2242///
2243/// Each header's value can be customised via [`SecurityHeadersConfig`]
2244/// on [`McpServerConfig`]. See that type for the three-state semantic
2245/// (`None` = default, `Some("")` = omit, `Some(v)` = override).
2246async fn security_headers_middleware(
2247    is_tls: bool,
2248    cfg: Arc<SecurityHeadersConfig>,
2249    req: Request<Body>,
2250    next: Next,
2251) -> axum::response::Response {
2252    use axum::http::{HeaderName, header};
2253
2254    let mut resp = next.run(req).await;
2255    let headers = resp.headers_mut();
2256
2257    // Strip server identity headers to reduce information leakage.
2258    headers.remove(header::SERVER);
2259    headers.remove(HeaderName::from_static("x-powered-by"));
2260
2261    apply_security_header(
2262        headers,
2263        header::X_CONTENT_TYPE_OPTIONS,
2264        cfg.x_content_type_options.as_deref(),
2265        "nosniff",
2266    );
2267    apply_security_header(
2268        headers,
2269        header::X_FRAME_OPTIONS,
2270        cfg.x_frame_options.as_deref(),
2271        "deny",
2272    );
2273    apply_security_header(
2274        headers,
2275        header::CACHE_CONTROL,
2276        cfg.cache_control.as_deref(),
2277        "no-store, max-age=0",
2278    );
2279    apply_security_header(
2280        headers,
2281        header::REFERRER_POLICY,
2282        cfg.referrer_policy.as_deref(),
2283        "no-referrer",
2284    );
2285    apply_security_header(
2286        headers,
2287        HeaderName::from_static("cross-origin-opener-policy"),
2288        cfg.cross_origin_opener_policy.as_deref(),
2289        "same-origin",
2290    );
2291    apply_security_header(
2292        headers,
2293        HeaderName::from_static("cross-origin-resource-policy"),
2294        cfg.cross_origin_resource_policy.as_deref(),
2295        "same-origin",
2296    );
2297    apply_security_header(
2298        headers,
2299        HeaderName::from_static("cross-origin-embedder-policy"),
2300        cfg.cross_origin_embedder_policy.as_deref(),
2301        "require-corp",
2302    );
2303    apply_security_header(
2304        headers,
2305        HeaderName::from_static("permissions-policy"),
2306        cfg.permissions_policy.as_deref(),
2307        "accelerometer=(), camera=(), geolocation=(), microphone=()",
2308    );
2309    apply_security_header(
2310        headers,
2311        HeaderName::from_static("x-permitted-cross-domain-policies"),
2312        cfg.x_permitted_cross_domain_policies.as_deref(),
2313        "none",
2314    );
2315    apply_security_header(
2316        headers,
2317        HeaderName::from_static("content-security-policy"),
2318        cfg.content_security_policy.as_deref(),
2319        "default-src 'none'; frame-ancestors 'none'",
2320    );
2321    apply_security_header(
2322        headers,
2323        HeaderName::from_static("x-dns-prefetch-control"),
2324        cfg.x_dns_prefetch_control.as_deref(),
2325        "off",
2326    );
2327
2328    if is_tls {
2329        apply_security_header(
2330            headers,
2331            header::STRICT_TRANSPORT_SECURITY,
2332            cfg.strict_transport_security.as_deref(),
2333            "max-age=63072000; includeSubDomains",
2334        );
2335    }
2336
2337    resp
2338}
2339
2340/// Set a single security header on the response, honouring the
2341/// three-state override semantic (None = default, Some("") = omit,
2342/// Some(value) = override).
2343///
2344/// Defence-in-depth: if an override value somehow reaches this point
2345/// despite [`validate_security_headers`] having approved it (e.g. a
2346/// runtime mutation on a non-`Validated` field), we log at error level
2347/// and fall back to the static default rather than panicking. The
2348/// `Validated<McpServerConfig>` type makes that path unreachable in
2349/// well-typed code paths.
2350fn apply_security_header(
2351    headers: &mut axum::http::HeaderMap,
2352    name: axum::http::HeaderName,
2353    override_value: Option<&str>,
2354    default: &'static str,
2355) {
2356    use axum::http::HeaderValue;
2357
2358    match override_value {
2359        None => {
2360            headers.insert(name, HeaderValue::from_static(default));
2361        }
2362        Some("") => {
2363            // Operator explicitly opted out of this header.
2364        }
2365        Some(v) => match HeaderValue::from_str(v) {
2366            Ok(hv) => {
2367                headers.insert(name, hv);
2368            }
2369            Err(err) => {
2370                tracing::error!(
2371                    header = %name,
2372                    error = %err,
2373                    "invalid security header override reached middleware; using default"
2374                );
2375                headers.insert(name, HeaderValue::from_static(default));
2376            }
2377        },
2378    }
2379}
2380
2381/// Validate every non-empty entry in a [`SecurityHeadersConfig`].
2382///
2383/// - `None` and `Some("")` are accepted unconditionally (use-default and
2384///   omit, respectively).
2385/// - `Some(v)` is rejected if `axum::http::HeaderValue::from_str(v)` fails.
2386/// - `strict_transport_security` additionally rejects any value
2387///   containing `preload` (case-insensitive). Operators who genuinely
2388///   want to commit to the HSTS preload list must do so via a future
2389///   explicit `with_hsts_preload(true)` builder, not by smuggling
2390///   `preload` through this knob.
2391fn validate_security_headers(cfg: &SecurityHeadersConfig) -> Result<(), McpxError> {
2392    use axum::http::HeaderValue;
2393
2394    let fields: &[(&str, Option<&str>)] = &[
2395        (
2396            "x_content_type_options",
2397            cfg.x_content_type_options.as_deref(),
2398        ),
2399        ("x_frame_options", cfg.x_frame_options.as_deref()),
2400        ("cache_control", cfg.cache_control.as_deref()),
2401        ("referrer_policy", cfg.referrer_policy.as_deref()),
2402        (
2403            "cross_origin_opener_policy",
2404            cfg.cross_origin_opener_policy.as_deref(),
2405        ),
2406        (
2407            "cross_origin_resource_policy",
2408            cfg.cross_origin_resource_policy.as_deref(),
2409        ),
2410        (
2411            "cross_origin_embedder_policy",
2412            cfg.cross_origin_embedder_policy.as_deref(),
2413        ),
2414        ("permissions_policy", cfg.permissions_policy.as_deref()),
2415        (
2416            "x_permitted_cross_domain_policies",
2417            cfg.x_permitted_cross_domain_policies.as_deref(),
2418        ),
2419        (
2420            "content_security_policy",
2421            cfg.content_security_policy.as_deref(),
2422        ),
2423        (
2424            "x_dns_prefetch_control",
2425            cfg.x_dns_prefetch_control.as_deref(),
2426        ),
2427        (
2428            "strict_transport_security",
2429            cfg.strict_transport_security.as_deref(),
2430        ),
2431    ];
2432
2433    for (field, value) in fields {
2434        let Some(v) = value else { continue };
2435        if v.is_empty() {
2436            continue;
2437        }
2438        if let Err(err) = HeaderValue::from_str(v) {
2439            return Err(McpxError::Config(format!(
2440                "invalid security_headers.{field}: {err}"
2441            )));
2442        }
2443    }
2444
2445    if let Some(v) = cfg.strict_transport_security.as_deref()
2446        && !v.is_empty()
2447        && v.to_ascii_lowercase().contains("preload")
2448    {
2449        return Err(McpxError::Config(format!(
2450            "invalid security_headers.strict_transport_security: {v:?} contains the `preload` directive; \
2451             HSTS preload must be opted into explicitly via a dedicated builder, not via this knob"
2452        )));
2453    }
2454
2455    Ok(())
2456}
2457
/// Add the anti-caching headers expected on OAuth token-issuing responses.
///
/// After the inner service has produced its response this middleware:
///
/// - sets `Pragma: no-cache` (RFC 6749 §5.1 requires it alongside
///   `Cache-Control: no-store` on token responses), replacing any existing
///   `Pragma` value;
/// - appends `Vary: Authorization`, because the response depends on the
///   `Authorization` request header.
///
/// `Cache-Control` itself is not touched here; [`security_headers_middleware`]
/// already emits `no-store, max-age=0` by default.
///
/// `Vary` is *appended* rather than inserted so that a value set earlier
/// (for example `Accept-Encoding` from a compression layer, or `Origin`
/// from a CORS layer) survives.
///
/// Intended for the OAuth proxy token-class endpoints (`/token`,
/// `/register`, `/introspect`, `/revoke`); the routing that attaches it
/// lives outside this function.
#[cfg(feature = "oauth")]
async fn oauth_token_cache_headers_middleware(
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    use axum::http::{
        HeaderValue,
        header::{PRAGMA, VARY},
    };

    let mut response = next.run(req).await;

    let out = response.headers_mut();
    out.insert(PRAGMA, HeaderValue::from_static("no-cache"));
    out.append(VARY, HeaderValue::from_static("Authorization"));

    response
}
2485
2486/// Per the MCP spec: if the Origin header is present and its value is not in
2487/// the allowed list, respond with 403 Forbidden. Requests without an Origin
2488/// header are allowed through (e.g. non-browser clients like curl, SDKs).
2489async fn origin_check_middleware(
2490    allowed: Arc<[String]>,
2491    log_request_headers: bool,
2492    req: Request<Body>,
2493    next: Next,
2494) -> axum::response::Response {
2495    let method = req.method().clone();
2496    let path = req.uri().path().to_owned();
2497
2498    log_incoming_request(&method, &path, req.headers(), log_request_headers);
2499
2500    if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
2501        let origin_str = origin.to_str().unwrap_or("");
2502        if !allowed.iter().any(|a| a == origin_str) {
2503            tracing::warn!(
2504                origin = origin_str,
2505                %method,
2506                %path,
2507                allowed = ?&*allowed,
2508                "rejected request: Origin not allowed"
2509            );
2510            return (
2511                axum::http::StatusCode::FORBIDDEN,
2512                "Forbidden: Origin not allowed",
2513            )
2514                .into_response();
2515        }
2516    }
2517    next.run(req).await
2518}
2519
2520/// Emit a DEBUG log for an incoming request, optionally including the full
2521/// (redacted) header set.
2522fn log_incoming_request(
2523    method: &axum::http::Method,
2524    path: &str,
2525    headers: &axum::http::HeaderMap,
2526    log_request_headers: bool,
2527) {
2528    if log_request_headers {
2529        tracing::debug!(
2530            %method,
2531            %path,
2532            headers = %format_request_headers_for_log(headers),
2533            "incoming request"
2534        );
2535    } else {
2536        tracing::debug!(%method, %path, "incoming request");
2537    }
2538}
2539
2540fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
2541    headers
2542        .iter()
2543        .map(|(k, v)| {
2544            let name = k.as_str();
2545            if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
2546                format!("{name}: [REDACTED]")
2547            } else {
2548                format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
2549            }
2550        })
2551        .collect::<Vec<_>>()
2552        .join(", ")
2553}
2554
2555// -- stdio transport --
2556
2557/// Serve an MCP server over stdin/stdout (stdio transport).
2558///
2559/// # Security warnings
2560///
2561/// - **No authentication**: the parent process has full, unrestricted access.
2562/// - **No RBAC**: all tools are available regardless of policy.
2563/// - **No TLS**: messages travel over OS pipes in plaintext.
2564/// - **Single client**: only the parent process can connect.
2565/// - **No Origin validation**: not applicable to stdio.
2566///
2567/// Use this only when the MCP client spawns the server as a trusted subprocess
2568/// (e.g. Claude Desktop, VS Code Copilot). For network-accessible deployments,
2569/// use `serve()` (Streamable HTTP) instead.
2570///
2571/// # Errors
2572///
2573/// Returns [`McpxError::Startup`] if the handler fails to initialize or the
2574/// transport disconnects unexpectedly.
2575// NOTE: reported complexity 32/25 is driven entirely by `tracing::*!`
2576// macro expansion in this 18-line function (info/warn/info + two matches).
2577// There is nothing meaningful to extract; the allow stays.
2578#[allow(
2579    clippy::cognitive_complexity,
2580    reason = "complexity is purely tracing macro expansion (info/warn + match arms); 18 lines of straight-line code, nothing meaningful to extract"
2581)]
2582pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
2583where
2584    H: ServerHandler + 'static,
2585{
2586    use rmcp::ServiceExt as _;
2587
2588    tracing::info!("stdio transport: serving on stdin/stdout");
2589    tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
2590
2591    let transport = rmcp::transport::io::stdio();
2592
2593    let service = handler
2594        .serve(transport)
2595        .await
2596        .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
2597
2598    if let Err(e) = service.waiting().await {
2599        tracing::warn!(error = %e, "stdio session ended with error");
2600    }
2601    tracing::info!("stdio session ended");
2602    Ok(())
2603}
2604
2605#[cfg(test)]
2606mod tests {
2607    #![allow(
2608        clippy::unwrap_used,
2609        clippy::expect_used,
2610        clippy::panic,
2611        clippy::indexing_slicing,
2612        clippy::unwrap_in_result,
2613        clippy::print_stdout,
2614        clippy::print_stderr,
2615        deprecated,
2616        reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
2617    )]
2618    use std::sync::Arc;
2619
2620    use axum::{
2621        body::Body,
2622        http::{Request, StatusCode, header},
2623        response::IntoResponse,
2624    };
2625    use http_body_util::BodyExt;
2626    use tower::ServiceExt as _;
2627
2628    use super::*;
2629
2630    // -- McpServerConfig --
2631
2632    #[test]
2633    fn server_config_new_defaults() {
2634        let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
2635        assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
2636        assert_eq!(cfg.name, "test-server");
2637        assert_eq!(cfg.version, "1.0.0");
2638        assert!(cfg.tls_cert_path.is_none());
2639        assert!(cfg.tls_key_path.is_none());
2640        assert!(cfg.auth.is_none());
2641        assert!(cfg.rbac.is_none());
2642        assert!(cfg.allowed_origins.is_empty());
2643        assert!(cfg.tool_rate_limit.is_none());
2644        assert!(cfg.readiness_check.is_none());
2645        assert_eq!(cfg.max_request_body, 1024 * 1024);
2646        assert_eq!(cfg.request_timeout, Duration::from_mins(2));
2647        assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
2648        assert!(!cfg.log_request_headers);
2649    }
2650
2651    #[test]
2652    fn validate_consumes_and_proves() {
2653        // Valid config -> Validated wrapper, original is consumed.
2654        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
2655        let validated = cfg.validate().expect("valid config");
2656        // as_inner() gives read-only access to inner fields.
2657        assert_eq!(validated.as_inner().name, "test-server");
2658        // into_inner recovers the raw value.
2659        let raw = validated.into_inner();
2660        assert_eq!(raw.name, "test-server");
2661
2662        // Invalid config (zero max_request_body) -> Err.
2663        let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
2664        bad.max_request_body = 0;
2665        assert!(bad.validate().is_err(), "zero body cap must fail validate");
2666    }
2667
2668    #[test]
2669    fn validate_rejects_zero_max_concurrent_requests() {
2670        let cfg =
2671            McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_max_concurrent_requests(0);
2672        let err = cfg.validate().expect_err("zero concurrency cap must fail");
2673        assert!(
2674            format!("{err}").contains("max_concurrent_requests"),
2675            "error should mention max_concurrent_requests, got: {err}"
2676        );
2677    }
2678
2679    #[test]
2680    fn validate_rejects_zero_max_tracked_keys() {
2681        let rl = crate::auth::RateLimitConfig {
2682            max_tracked_keys: 0,
2683            ..Default::default()
2684        };
2685        let auth_cfg = AuthConfig {
2686            enabled: true,
2687            rate_limit: Some(rl),
2688            ..Default::default()
2689        };
2690        let cfg = McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_auth(auth_cfg);
2691        let err = cfg.validate().expect_err("zero max_tracked_keys must fail");
2692        assert!(
2693            format!("{err}").contains("max_tracked_keys"),
2694            "error should mention max_tracked_keys, got: {err}"
2695        );
2696    }
2697
2698    #[test]
2699    fn derive_allowed_hosts_includes_public_host() {
2700        let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
2701        assert!(
2702            hosts.iter().any(|h| h == "mcp.example.com"),
2703            "public_url host must be allowed"
2704        );
2705    }
2706
2707    #[test]
2708    fn derive_allowed_hosts_includes_bind_authority() {
2709        let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
2710        assert!(
2711            hosts.iter().any(|h| h == "127.0.0.1"),
2712            "bind host must be allowed"
2713        );
2714        assert!(
2715            hosts.iter().any(|h| h == "127.0.0.1:8080"),
2716            "bind authority must be allowed"
2717        );
2718    }
2719
2720    // -- healthz --
2721
2722    #[tokio::test]
2723    async fn healthz_returns_ok_json() {
2724        let resp = healthz().await.into_response();
2725        assert_eq!(resp.status(), StatusCode::OK);
2726        let body = resp.into_body().collect().await.unwrap().to_bytes();
2727        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
2728        assert_eq!(json["status"], "ok");
2729        assert!(
2730            json.get("name").is_none(),
2731            "healthz must not expose server name"
2732        );
2733        assert!(
2734            json.get("version").is_none(),
2735            "healthz must not expose version"
2736        );
2737    }
2738
2739    // -- readyz --
2740
2741    #[tokio::test]
2742    async fn readyz_returns_ok_when_ready() {
2743        let check: ReadinessCheck =
2744            Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
2745        let resp = readyz(check).await.into_response();
2746        assert_eq!(resp.status(), StatusCode::OK);
2747        let body = resp.into_body().collect().await.unwrap().to_bytes();
2748        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
2749        assert_eq!(json["ready"], true);
2750        assert!(
2751            json.get("name").is_none(),
2752            "readyz must not expose server name"
2753        );
2754        assert!(
2755            json.get("version").is_none(),
2756            "readyz must not expose version"
2757        );
2758        assert_eq!(json["db"], "connected");
2759    }
2760
2761    #[tokio::test]
2762    async fn readyz_returns_503_when_not_ready() {
2763        let check: ReadinessCheck =
2764            Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
2765        let resp = readyz(check).await.into_response();
2766        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
2767    }
2768
2769    #[tokio::test]
2770    async fn readyz_returns_503_when_ready_missing() {
2771        let check: ReadinessCheck =
2772            Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
2773        let resp = readyz(check).await.into_response();
2774        // Missing "ready" field defaults to false -> 503
2775        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
2776    }
2777
2778    // -- origin_check_middleware --
2779
2780    /// Build a test router with origin check middleware and a simple handler.
2781    fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
2782        let allowed: Arc<[String]> = Arc::from(origins);
2783        axum::Router::new()
2784            .route("/test", axum::routing::get(|| async { "ok" }))
2785            .layer(axum::middleware::from_fn(move |req, next| {
2786                let a = Arc::clone(&allowed);
2787                origin_check_middleware(a, log_request_headers, req, next)
2788            }))
2789    }
2790
2791    #[tokio::test]
2792    async fn origin_allowed_passes() {
2793        let app = origin_router(vec!["http://localhost:3000".into()], false);
2794        let req = Request::builder()
2795            .uri("/test")
2796            .header(header::ORIGIN, "http://localhost:3000")
2797            .body(Body::empty())
2798            .unwrap();
2799        let resp = app.oneshot(req).await.unwrap();
2800        assert_eq!(resp.status(), StatusCode::OK);
2801    }
2802
2803    #[tokio::test]
2804    async fn origin_rejected_returns_403() {
2805        let app = origin_router(vec!["http://localhost:3000".into()], false);
2806        let req = Request::builder()
2807            .uri("/test")
2808            .header(header::ORIGIN, "http://evil.com")
2809            .body(Body::empty())
2810            .unwrap();
2811        let resp = app.oneshot(req).await.unwrap();
2812        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
2813    }
2814
2815    #[tokio::test]
2816    async fn no_origin_header_passes() {
2817        let app = origin_router(vec!["http://localhost:3000".into()], false);
2818        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2819        let resp = app.oneshot(req).await.unwrap();
2820        assert_eq!(resp.status(), StatusCode::OK);
2821    }
2822
2823    #[tokio::test]
2824    async fn empty_allowlist_rejects_any_origin() {
2825        let app = origin_router(vec![], false);
2826        let req = Request::builder()
2827            .uri("/test")
2828            .header(header::ORIGIN, "http://anything.com")
2829            .body(Body::empty())
2830            .unwrap();
2831        let resp = app.oneshot(req).await.unwrap();
2832        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
2833    }
2834
2835    #[tokio::test]
2836    async fn empty_allowlist_passes_without_origin() {
2837        let app = origin_router(vec![], false);
2838        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2839        let resp = app.oneshot(req).await.unwrap();
2840        assert_eq!(resp.status(), StatusCode::OK);
2841    }
2842
2843    #[test]
2844    fn format_request_headers_redacts_sensitive_values() {
2845        let mut headers = axum::http::HeaderMap::new();
2846        headers.insert("authorization", "Bearer secret-token".parse().unwrap());
2847        headers.insert("cookie", "sid=abc".parse().unwrap());
2848        headers.insert("x-request-id", "req-123".parse().unwrap());
2849
2850        let out = format_request_headers_for_log(&headers);
2851        assert!(out.contains("authorization: [REDACTED]"));
2852        assert!(out.contains("cookie: [REDACTED]"));
2853        assert!(out.contains("x-request-id: req-123"));
2854        assert!(!out.contains("secret-token"));
2855    }
2856
2857    // -- security_headers_middleware --
2858
2859    fn security_router(is_tls: bool) -> axum::Router {
2860        security_router_with(is_tls, SecurityHeadersConfig::default())
2861    }
2862
2863    fn security_router_with(is_tls: bool, cfg: SecurityHeadersConfig) -> axum::Router {
2864        let cfg = Arc::new(cfg);
2865        axum::Router::new()
2866            .route("/test", axum::routing::get(|| async { "ok" }))
2867            .layer(axum::middleware::from_fn(move |req, next| {
2868                let c = Arc::clone(&cfg);
2869                security_headers_middleware(is_tls, c, req, next)
2870            }))
2871    }
2872
2873    #[tokio::test]
2874    async fn security_headers_set_on_response() {
2875        let app = security_router(false);
2876        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2877        let resp = app.oneshot(req).await.unwrap();
2878        assert_eq!(resp.status(), StatusCode::OK);
2879
2880        let h = resp.headers();
2881        assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
2882        assert_eq!(h.get("x-frame-options").unwrap(), "deny");
2883        assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
2884        assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
2885        assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
2886        assert_eq!(
2887            h.get("cross-origin-resource-policy").unwrap(),
2888            "same-origin"
2889        );
2890        assert_eq!(
2891            h.get("cross-origin-embedder-policy").unwrap(),
2892            "require-corp"
2893        );
2894        assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
2895        assert!(
2896            h.get("permissions-policy")
2897                .unwrap()
2898                .to_str()
2899                .unwrap()
2900                .contains("camera=()"),
2901            "permissions-policy must restrict browser features"
2902        );
2903        assert_eq!(
2904            h.get("content-security-policy").unwrap(),
2905            "default-src 'none'; frame-ancestors 'none'"
2906        );
2907        assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
2908        // No HSTS when TLS is off.
2909        assert!(h.get("strict-transport-security").is_none());
2910    }
2911
2912    #[tokio::test]
2913    async fn hsts_set_when_tls_enabled() {
2914        let app = security_router(true);
2915        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2916        let resp = app.oneshot(req).await.unwrap();
2917
2918        let hsts = resp.headers().get("strict-transport-security").unwrap();
2919        assert!(
2920            hsts.to_str().unwrap().contains("max-age=63072000"),
2921            "HSTS must set 2-year max-age"
2922        );
2923    }
2924
2925    // -- SecurityHeadersConfig validation + override semantics --
2926
2927    /// Build a minimal config with a custom SecurityHeadersConfig and
2928    /// drive it through `check()`. Returns the result so individual
2929    /// tests can assert on success or specific error messages.
2930    fn check_with_security_headers(headers: SecurityHeadersConfig) -> Result<(), McpxError> {
2931        let cfg =
2932            McpServerConfig::new("127.0.0.1:8080", "test", "0.0.0").with_security_headers(headers);
2933        cfg.check()
2934    }
2935
2936    #[test]
2937    fn security_headers_config_default_validates() {
2938        check_with_security_headers(SecurityHeadersConfig::default())
2939            .expect("default SecurityHeadersConfig must validate");
2940    }
2941
2942    #[test]
2943    fn security_headers_config_validate_accepts_empty_string() {
2944        // All twelve fields explicitly set to "" -> omit-everything mode.
2945        let h = SecurityHeadersConfig {
2946            x_content_type_options: Some(String::new()),
2947            x_frame_options: Some(String::new()),
2948            cache_control: Some(String::new()),
2949            referrer_policy: Some(String::new()),
2950            cross_origin_opener_policy: Some(String::new()),
2951            cross_origin_resource_policy: Some(String::new()),
2952            cross_origin_embedder_policy: Some(String::new()),
2953            permissions_policy: Some(String::new()),
2954            x_permitted_cross_domain_policies: Some(String::new()),
2955            content_security_policy: Some(String::new()),
2956            x_dns_prefetch_control: Some(String::new()),
2957            strict_transport_security: Some(String::new()),
2958        };
2959        check_with_security_headers(h).expect("Some(\"\") on every field must validate (omit-all)");
2960    }
2961
2962    #[test]
2963    fn security_headers_config_validate_rejects_bad_value() {
2964        // 0x07 (BEL) is not a valid HTTP header value char.
2965        let h = SecurityHeadersConfig {
2966            referrer_policy: Some("\u{0007}".into()),
2967            ..SecurityHeadersConfig::default()
2968        };
2969        let err = check_with_security_headers(h)
2970            .expect_err("control char in referrer_policy must reject");
2971        let msg = err.to_string();
2972        assert!(
2973            msg.contains("referrer_policy"),
2974            "error must name the offending field, got: {msg}"
2975        );
2976    }
2977
2978    #[test]
2979    fn security_headers_config_validate_rejects_hsts_preload() {
2980        let h = SecurityHeadersConfig {
2981            strict_transport_security: Some("max-age=63072000; includeSubDomains; preload".into()),
2982            ..SecurityHeadersConfig::default()
2983        };
2984        let err = check_with_security_headers(h).expect_err("HSTS with preload must reject");
2985        let msg = err.to_string();
2986        assert!(
2987            msg.contains("strict_transport_security"),
2988            "error must name the field, got: {msg}"
2989        );
2990        assert!(
2991            msg.to_lowercase().contains("preload"),
2992            "error must mention `preload`, got: {msg}"
2993        );
2994    }
2995
2996    #[test]
2997    fn security_headers_config_validate_rejects_hsts_preload_uppercase() {
2998        // Case-insensitive match.
2999        let h = SecurityHeadersConfig {
3000            strict_transport_security: Some("max-age=600; PRELOAD".into()),
3001            ..SecurityHeadersConfig::default()
3002        };
3003        check_with_security_headers(h).expect_err("HSTS preload check must be case-insensitive");
3004    }
3005
3006    #[tokio::test]
3007    async fn security_headers_override_honored() {
3008        // Override X-Frame-Options to SAMEORIGIN.
3009        let h = SecurityHeadersConfig {
3010            x_frame_options: Some("SAMEORIGIN".into()),
3011            ..SecurityHeadersConfig::default()
3012        };
3013        let app = security_router_with(false, h);
3014        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3015        let resp = app.oneshot(req).await.unwrap();
3016        assert_eq!(resp.status(), StatusCode::OK);
3017
3018        let xfo = resp.headers().get("x-frame-options").unwrap();
3019        assert_eq!(xfo, "SAMEORIGIN");
3020    }
3021
3022    #[tokio::test]
3023    async fn security_headers_empty_string_omits() {
3024        // Empty string on referrer-policy -> header absent.
3025        let h = SecurityHeadersConfig {
3026            referrer_policy: Some(String::new()),
3027            ..SecurityHeadersConfig::default()
3028        };
3029        let app = security_router_with(false, h);
3030        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3031        let resp = app.oneshot(req).await.unwrap();
3032        assert_eq!(resp.status(), StatusCode::OK);
3033
3034        assert!(
3035            resp.headers().get("referrer-policy").is_none(),
3036            "Some(\"\") must omit the header"
3037        );
3038        // Other defaults should still be present.
3039        assert_eq!(
3040            resp.headers().get("x-content-type-options").unwrap(),
3041            "nosniff"
3042        );
3043    }
3044
3045    #[tokio::test]
3046    async fn security_headers_hsts_only_when_tls() {
3047        // HSTS override is irrelevant when TLS is off.
3048        let h = SecurityHeadersConfig {
3049            strict_transport_security: Some("max-age=600".into()),
3050            ..SecurityHeadersConfig::default()
3051        };
3052        let app = security_router_with(false, h);
3053        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3054        let resp = app.oneshot(req).await.unwrap();
3055        assert!(
3056            resp.headers().get("strict-transport-security").is_none(),
3057            "HSTS must remain absent on plaintext deployments even with override"
3058        );
3059    }
3060
3061    // -- oauth_token_cache_headers_middleware --
3062
3063    #[cfg(feature = "oauth")]
3064    #[tokio::test]
3065    async fn oauth_token_cache_headers_set_pragma_and_vary() {
3066        let app = axum::Router::new()
3067            .route("/token", axum::routing::post(|| async { "{}" }))
3068            .layer(axum::middleware::from_fn(
3069                oauth_token_cache_headers_middleware,
3070            ));
3071        let req = Request::builder()
3072            .method("POST")
3073            .uri("/token")
3074            .body(Body::from("{}"))
3075            .unwrap();
3076        let resp = app.oneshot(req).await.unwrap();
3077        assert_eq!(resp.status(), StatusCode::OK);
3078
3079        let h = resp.headers();
3080        assert_eq!(
3081            h.get("pragma").unwrap(),
3082            "no-cache",
3083            "RFC 6749 §5.1: token responses must set Pragma: no-cache"
3084        );
3085        let vary_values: Vec<String> = h
3086            .get_all("vary")
3087            .iter()
3088            .filter_map(|v| v.to_str().ok().map(str::to_owned))
3089            .collect();
3090        assert!(
3091            vary_values
3092                .iter()
3093                .any(|v| v.eq_ignore_ascii_case("Authorization")),
3094            "RFC 6750 §5.4: Vary must include Authorization, got {vary_values:?}"
3095        );
3096    }
3097
3098    #[cfg(feature = "oauth")]
3099    #[tokio::test]
3100    async fn oauth_token_cache_headers_preserve_existing_vary() {
3101        // Simulates a handler/layer that already set `Vary: Accept-Encoding`
3102        // (e.g. compression). Our middleware must APPEND, not REPLACE.
3103        let app = axum::Router::new()
3104            .route(
3105                "/token",
3106                axum::routing::post(|| async {
3107                    axum::response::Response::builder()
3108                        .header("vary", "Accept-Encoding")
3109                        .body(axum::body::Body::from("{}"))
3110                        .unwrap()
3111                }),
3112            )
3113            .layer(axum::middleware::from_fn(
3114                oauth_token_cache_headers_middleware,
3115            ));
3116        let req = Request::builder()
3117            .method("POST")
3118            .uri("/token")
3119            .body(Body::empty())
3120            .unwrap();
3121        let resp = app.oneshot(req).await.unwrap();
3122
3123        let vary: Vec<String> = resp
3124            .headers()
3125            .get_all("vary")
3126            .iter()
3127            .filter_map(|v| v.to_str().ok().map(str::to_owned))
3128            .collect();
3129        assert!(
3130            vary.iter().any(|v| v.contains("Accept-Encoding")),
3131            "must preserve pre-existing Vary value, got {vary:?}"
3132        );
3133        assert!(
3134            vary.iter().any(|v| v.contains("Authorization")),
3135            "must append Authorization to Vary, got {vary:?}"
3136        );
3137    }
3138
3139    // -- version endpoint --
3140
3141    #[test]
3142    fn version_payload_contains_expected_fields() {
3143        let v = version_payload("my-server", "1.2.3");
3144        assert_eq!(v["name"], "my-server");
3145        assert_eq!(v["version"], "1.2.3");
3146        assert!(v["build_git_sha"].is_string());
3147        assert!(v["build_timestamp"].is_string());
3148        assert!(v["rust_version"].is_string());
3149        assert!(v["mcpx_version"].is_string());
3150    }
3151
3152    // -- concurrency limit layer --
3153
3154    #[tokio::test]
3155    async fn concurrency_limit_layer_composes_and_serves() {
3156        // We only assert the layer stack compiles and a single request
3157        // below the cap still succeeds. True back-pressure behaviour
3158        // requires a live HTTP server and is covered by integration tests.
3159        let app = axum::Router::new()
3160            .route("/ok", axum::routing::get(|| async { "ok" }))
3161            .layer(
3162                tower::ServiceBuilder::new()
3163                    .layer(axum::error_handling::HandleErrorLayer::new(
3164                        |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
3165                    ))
3166                    .layer(tower::load_shed::LoadShedLayer::new())
3167                    .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
3168            );
3169        let resp = app
3170            .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
3171            .await
3172            .unwrap();
3173        assert_eq!(resp.status(), StatusCode::OK);
3174    }
3175
3176    // -- compression layer --
3177
3178    #[tokio::test]
3179    async fn compression_layer_gzip_encodes_response() {
3180        use tower_http::compression::Predicate as _;
3181
3182        let big_body = "a".repeat(4096);
3183        let app = axum::Router::new()
3184            .route(
3185                "/big",
3186                axum::routing::get(move || {
3187                    let body = big_body.clone();
3188                    async move { body }
3189                }),
3190            )
3191            .layer(
3192                tower_http::compression::CompressionLayer::new()
3193                    .gzip(true)
3194                    .br(true)
3195                    .compress_when(
3196                        tower_http::compression::DefaultPredicate::new()
3197                            .and(tower_http::compression::predicate::SizeAbove::new(1024)),
3198                    ),
3199            );
3200
3201        let req = Request::builder()
3202            .uri("/big")
3203            .header(header::ACCEPT_ENCODING, "gzip")
3204            .body(Body::empty())
3205            .unwrap();
3206        let resp = app.oneshot(req).await.unwrap();
3207        assert_eq!(resp.status(), StatusCode::OK);
3208        assert_eq!(
3209            resp.headers().get(header::CONTENT_ENCODING).unwrap(),
3210            "gzip"
3211        );
3212    }
3213}