Skip to main content

rmcp_server_kit/
transport.rs

1use std::{
2    future::Future,
3    net::SocketAddr,
4    path::{Path, PathBuf},
5    pin::Pin,
6    sync::Arc,
7    time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{body::Body, extract::Request, middleware::Next, response::IntoResponse};
12use rmcp::{
13    ServerHandler,
14    transport::streamable_http_server::{
15        StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
16    },
17};
18use rustls::RootCertStore;
19use tokio::{
20    net::TcpListener,
21    sync::{Semaphore, mpsc},
22};
23use tokio_util::sync::CancellationToken;
24
25use crate::{
26    auth::{
27        AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
28        build_rate_limiter, extract_mtls_identity,
29    },
30    error::McpxError,
31    mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
32    rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
33};
34
35/// Map an internal `anyhow::Error` chain into a public [`McpxError::Startup`]
36/// at the public API boundary, flattening the chain via the alternate
37/// formatter so callers see the full causal path.
38#[allow(
39    clippy::needless_pass_by_value,
40    reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
41)]
42fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
43    McpxError::Startup(format!("{e:#}"))
44}
45
46/// Map a `std::io::Error` produced during server startup into a public
47/// [`McpxError::Startup`]. We deliberately do not use the [`McpxError::Io`]
48/// `From` impl here because startup-phase IO errors (bind, listener) are
49/// semantically distinct from request-time IO errors and should surface
50/// the originating operation in the message.
51#[allow(
52    clippy::needless_pass_by_value,
53    reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
54)]
55fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
56    McpxError::Startup(format!("{op}: {e}"))
57}
58
59/// Async readiness check callback for the `/readyz` endpoint.
60///
61/// Returns a JSON object with at least a `"ready"` boolean.
62/// When `ready` is false, the endpoint returns HTTP 503.
63pub type ReadinessCheck =
64    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
65
/// Overrides for the OWASP security headers that the global response
/// middleware attaches to every response.
///
/// Every field is tri-state:
///
/// - `None`: keep the built-in default (the current behaviour).
/// - `Some("")` (empty string): **suppress** the header entirely.
/// - `Some(value)`: send `value` in place of the default.
///
/// Non-empty overrides are checked with
/// [`axum::http::HeaderValue::from_str`] inside
/// [`McpServerConfig::validate`], so malformed values are rejected before
/// the server accepts any traffic.
///
/// `Strict-Transport-Security` has one extra rule: any occurrence of
/// `preload` (matched case-insensitively) is refused. Opting in to the
/// HSTS preload list is reserved for a future, explicit builder method
/// rather than this general-purpose override.
#[non_exhaustive]
#[derive(Debug, Clone, Default)]
pub struct SecurityHeadersConfig {
    /// `X-Content-Type-Options` override (built-in default: `nosniff`).
    pub x_content_type_options: Option<String>,
    /// `X-Frame-Options` override (built-in default: `deny`).
    pub x_frame_options: Option<String>,
    /// `Cache-Control` override (built-in default: `no-store, max-age=0`).
    pub cache_control: Option<String>,
    /// `Referrer-Policy` override (built-in default: `no-referrer`).
    pub referrer_policy: Option<String>,
    /// `Cross-Origin-Opener-Policy` override (built-in default: `same-origin`).
    pub cross_origin_opener_policy: Option<String>,
    /// `Cross-Origin-Resource-Policy` override (built-in default: `same-origin`).
    pub cross_origin_resource_policy: Option<String>,
    /// `Cross-Origin-Embedder-Policy` override (built-in default: `require-corp`).
    pub cross_origin_embedder_policy: Option<String>,
    /// `Permissions-Policy` override (built-in default:
    /// `accelerometer=(), camera=(), geolocation=(), microphone=()`).
    pub permissions_policy: Option<String>,
    /// `X-Permitted-Cross-Domain-Policies` override (built-in default: `none`).
    pub x_permitted_cross_domain_policies: Option<String>,
    /// `Content-Security-Policy` override (built-in default:
    /// `default-src 'none'; frame-ancestors 'none'`).
    pub content_security_policy: Option<String>,
    /// `X-DNS-Prefetch-Control` override (built-in default: `off`).
    pub x_dns_prefetch_control: Option<String>,
    /// `Strict-Transport-Security` override. Built-in default, sent only
    /// when TLS is active: `max-age=63072000; includeSubDomains`. On
    /// plaintext deployments the header is never sent and this override
    /// is ignored. Values containing `preload` (any case) fail validation.
    pub strict_transport_security: Option<String>,
}
119
120/// Configuration for the MCP server.
121#[allow(
122    missing_debug_implementations,
123    reason = "contains callback/trait objects that don't impl Debug"
124)]
125#[allow(
126    clippy::struct_excessive_bools,
127    reason = "server configuration naturally has many boolean feature flags"
128)]
129#[non_exhaustive]
130pub struct McpServerConfig {
131    /// Socket address the MCP HTTP server binds to.
132    #[deprecated(
133        since = "0.13.0",
134        note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in a future major release"
135    )]
136    pub bind_addr: String,
137    /// Server name advertised via MCP `initialize`.
138    #[deprecated(
139        since = "0.13.0",
140        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
141    )]
142    pub name: String,
143    /// Server version advertised via MCP `initialize`.
144    #[deprecated(
145        since = "0.13.0",
146        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
147    )]
148    pub version: String,
149    /// Path to the TLS certificate (PEM). Required for TLS/mTLS.
150    #[deprecated(
151        since = "0.13.0",
152        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
153    )]
154    pub tls_cert_path: Option<PathBuf>,
155    /// Path to the TLS private key (PEM). Required for TLS/mTLS.
156    #[deprecated(
157        since = "0.13.0",
158        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
159    )]
160    pub tls_key_path: Option<PathBuf>,
161    /// Optional authentication config. When `Some` and `enabled`, auth
162    /// is enforced on `/mcp`. `/healthz` is always open.
163    #[deprecated(
164        since = "0.13.0",
165        note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in a future major release"
166    )]
167    pub auth: Option<AuthConfig>,
168    /// Optional RBAC policy. When present and enabled, tool calls are
169    /// checked against the policy after authentication.
170    #[deprecated(
171        since = "0.13.0",
172        note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in a future major release"
173    )]
174    pub rbac: Option<Arc<RbacPolicy>>,
175    /// Allowed Origin values for DNS rebinding protection (MCP spec MUST).
176    /// When empty and `public_url` is set, the origin is auto-derived from
177    /// the public URL. When both are empty, only requests with no Origin
178    /// header are accepted.
179    /// Example entries: `"http://localhost:3000"`, `"https://myapp.example.com"`.
180    #[deprecated(
181        since = "0.13.0",
182        note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in a future major release"
183    )]
184    pub allowed_origins: Vec<String>,
185    /// Maximum tool invocations per source IP per minute.
186    /// When set, enforced on every `tools/call` request.
187    #[deprecated(
188        since = "0.13.0",
189        note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in a future major release"
190    )]
191    pub tool_rate_limit: Option<u32>,
192    /// Optional readiness probe for `/readyz`.
193    /// When `None`, `/readyz` mirrors `/healthz` (always OK).
194    #[deprecated(
195        since = "0.13.0",
196        note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in a future major release"
197    )]
198    pub readiness_check: Option<ReadinessCheck>,
199    /// Maximum request body size in bytes. Default: 1 MiB.
200    /// Protects against oversized payloads causing OOM.
201    #[deprecated(
202        since = "0.13.0",
203        note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in a future major release"
204    )]
205    pub max_request_body: usize,
206    /// Request processing timeout. Default: 120s.
207    /// Requests exceeding this duration receive 408 Request Timeout.
208    #[deprecated(
209        since = "0.13.0",
210        note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in a future major release"
211    )]
212    pub request_timeout: Duration,
213    /// Graceful shutdown timeout. Default: 30s.
214    /// After the shutdown signal, in-flight requests have this long to finish.
215    #[deprecated(
216        since = "0.13.0",
217        note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in a future major release"
218    )]
219    pub shutdown_timeout: Duration,
220    /// Idle timeout for MCP sessions. Sessions with no activity for this
221    /// duration are closed automatically. Default: 20 minutes.
222    #[deprecated(
223        since = "0.13.0",
224        note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in a future major release"
225    )]
226    pub session_idle_timeout: Duration,
227    /// Interval for SSE keep-alive pings. Prevents proxies and load
228    /// balancers from killing idle connections. Default: 15 seconds.
229    #[deprecated(
230        since = "0.13.0",
231        note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in a future major release"
232    )]
233    pub sse_keep_alive: Duration,
234    /// Callback invoked once the server is built, delivering a
235    /// [`ReloadHandle`] for hot-reloading auth keys and RBAC policy
236    /// at runtime (e.g. on SIGHUP). Only useful when auth/RBAC is enabled.
237    #[deprecated(
238        since = "0.13.0",
239        note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in a future major release"
240    )]
241    pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
242    /// Additional application-specific routes merged into the top-level
243    /// router.  These routes **bypass** the MCP auth and RBAC middleware,
244    /// so the application is responsible for its own auth on them.
245    #[deprecated(
246        since = "0.13.0",
247        note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in a future major release"
248    )]
249    pub extra_router: Option<axum::Router>,
250    /// Externally reachable base URL (e.g. `https://mcp.example.com`).
251    /// When set, OAuth metadata endpoints advertise this URL instead of
252    /// the listen address. Required when binding `0.0.0.0` behind a
253    /// reverse proxy or inside a container.
254    #[deprecated(
255        since = "0.13.0",
256        note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in a future major release"
257    )]
258    pub public_url: Option<String>,
259    /// Log inbound HTTP request headers at DEBUG level.
260    /// Sensitive values remain redacted.
261    #[deprecated(
262        since = "0.13.0",
263        note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in a future major release"
264    )]
265    pub log_request_headers: bool,
266    /// Enable gzip/br response compression on MCP responses.
267    /// Defaults to `false` to preserve existing behaviour.
268    #[deprecated(
269        since = "0.13.0",
270        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
271    )]
272    pub compression_enabled: bool,
273    /// Minimum response body size (in bytes) before compression kicks in.
274    /// Only used when `compression_enabled` is true. Default: 1024.
275    #[deprecated(
276        since = "0.13.0",
277        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
278    )]
279    pub compression_min_size: u16,
280    /// Global cap on in-flight HTTP requests across the whole server.
281    /// When `Some`, requests over the cap receive 503 Service Unavailable
282    /// via `tower::load_shed`. Default: `None` (unlimited).
283    #[deprecated(
284        since = "0.13.0",
285        note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in a future major release"
286    )]
287    pub max_concurrent_requests: Option<usize>,
288    /// Enable `/admin/*` diagnostic endpoints. Requires `auth` to be
289    /// configured and `enabled`. Default: `false`.
290    #[deprecated(
291        since = "0.13.0",
292        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
293    )]
294    pub admin_enabled: bool,
295    /// RBAC role required to access admin endpoints. Default: `"admin"`.
296    #[deprecated(
297        since = "0.13.0",
298        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
299    )]
300    pub admin_role: String,
301    /// Enable Prometheus metrics endpoint on a separate listener.
302    /// Requires the `metrics` crate feature.
303    #[cfg(feature = "metrics")]
304    #[deprecated(
305        since = "0.13.0",
306        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
307    )]
308    pub metrics_enabled: bool,
309    /// Bind address for the Prometheus metrics listener. Default: `127.0.0.1:9090`.
310    #[cfg(feature = "metrics")]
311    #[deprecated(
312        since = "0.13.0",
313        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
314    )]
315    pub metrics_bind: String,
316    /// Per-header overrides for the OWASP security headers emitted by
317    /// the global response middleware. See [`SecurityHeadersConfig`]
318    /// for the three-state semantic and validation rules.
319    #[deprecated(
320        since = "1.5.0",
321        note = "use McpServerConfig::with_security_headers(); direct field access will become pub(crate) in a future major release"
322    )]
323    pub security_headers: SecurityHeadersConfig,
324}
325
/// Proof-carrying wrapper around a value that has passed its validation
/// contract.
///
/// The only way to obtain `Validated<McpServerConfig>` is to call
/// [`McpServerConfig::validate`]. [`serve`] and [`serve_with_listener`]
/// require this wrapper, so the type system enforces that validation has
/// run. The wrapped field is private, so code outside this module cannot
/// fabricate a `Validated` value and skip validation.
///
/// Borrow the contents read-only with [`Validated::as_inner`]. To make
/// changes, take the raw value back with [`Validated::into_inner`], modify
/// it, and validate it again.
///
/// # Example
///
/// ```no_run
/// use rmcp_server_kit::transport::{McpServerConfig, Validated, serve};
/// use rmcp::handler::server::ServerHandler;
/// use rmcp::model::{ServerCapabilities, ServerInfo};
///
/// #[derive(Clone)]
/// struct H;
/// impl ServerHandler for H {
///     fn get_info(&self) -> ServerInfo {
///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
///     }
/// }
///
/// # async fn example() -> rmcp_server_kit::Result<()> {
/// let config: Validated<McpServerConfig> =
///     McpServerConfig::new("127.0.0.1:8080", "my-server", "0.1.0").validate()?;
/// serve(config, || H).await
/// # }
/// ```
///
/// Forgetting `.validate()?` is a compile error:
///
/// ```compile_fail
/// use rmcp_server_kit::transport::{McpServerConfig, serve};
/// use rmcp::handler::server::ServerHandler;
/// use rmcp::model::{ServerCapabilities, ServerInfo};
///
/// #[derive(Clone)]
/// struct H;
/// impl ServerHandler for H {
///     fn get_info(&self) -> ServerInfo {
///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
///     }
/// }
///
/// # async fn example() -> rmcp_server_kit::Result<()> {
/// let config = McpServerConfig::new("127.0.0.1:8080", "my-server", "0.1.0");
/// // Missing `.validate()?` -> mismatched types: expected
/// // `Validated<McpServerConfig>`, found `McpServerConfig`.
/// serve(config, || H).await
/// # }
/// ```
pub struct Validated<T>(T);

/// Opaque `Debug`: always renders `Validated { .. }`.
///
/// The impl is written by hand rather than derived for two reasons. `T` is
/// not required to implement `Debug`, and the wrapped value's contents
/// never end up in logs. Because this impl covers every `T`, the type
/// already satisfies `missing_debug_implementations` without a lint
/// suppression.
impl<T> std::fmt::Debug for Validated<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Validated").finish_non_exhaustive()
    }
}

impl<T> Validated<T> {
    /// Borrow the validated value without giving up the proof.
    #[must_use]
    pub fn as_inner(&self) -> &T {
        &self.0
    }

    /// Unwrap the raw value, discarding the validation proof.
    ///
    /// Validate it again before passing it to [`serve`] or
    /// [`serve_with_listener`].
    #[must_use]
    pub fn into_inner(self) -> T {
        self.0
    }
}
411
412#[allow(
413    deprecated,
414    reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
415)]
416impl McpServerConfig {
417    /// Create a new server configuration with the given bind address,
418    /// server name, and version. All other fields use safe defaults.
419    ///
420    /// Use the chainable `with_*` / `enable_*` builder methods to
421    /// customize. Call [`McpServerConfig::validate`] to obtain a
422    /// [`Validated<McpServerConfig>`] proof token, which is required by
423    /// [`serve`] and [`serve_with_listener`].
424    #[must_use]
425    pub fn new(
426        bind_addr: impl Into<String>,
427        name: impl Into<String>,
428        version: impl Into<String>,
429    ) -> Self {
430        Self {
431            bind_addr: bind_addr.into(),
432            name: name.into(),
433            version: version.into(),
434            tls_cert_path: None,
435            tls_key_path: None,
436            auth: None,
437            rbac: None,
438            allowed_origins: Vec::new(),
439            tool_rate_limit: None,
440            readiness_check: None,
441            max_request_body: 1024 * 1024,
442            request_timeout: Duration::from_mins(2),
443            shutdown_timeout: Duration::from_secs(30),
444            session_idle_timeout: Duration::from_mins(20),
445            sse_keep_alive: Duration::from_secs(15),
446            on_reload_ready: None,
447            extra_router: None,
448            public_url: None,
449            log_request_headers: false,
450            compression_enabled: false,
451            compression_min_size: 1024,
452            max_concurrent_requests: None,
453            admin_enabled: false,
454            admin_role: "admin".to_owned(),
455            #[cfg(feature = "metrics")]
456            metrics_enabled: false,
457            #[cfg(feature = "metrics")]
458            metrics_bind: "127.0.0.1:9090".into(),
459            security_headers: SecurityHeadersConfig::default(),
460        }
461    }
462
463    // ---------------------------------------------------------------
464    // Builder methods (fluent, consume + return self).
465    //
466    // Each method is `#[must_use]` because dropping the returned
467    // `McpServerConfig` discards the configuration change.
468    // ---------------------------------------------------------------
469
470    /// Attach an authentication configuration. Required for
471    /// [`enable_admin`](Self::enable_admin) and any non-public deployment.
472    #[must_use]
473    pub fn with_auth(mut self, auth: AuthConfig) -> Self {
474        self.auth = Some(auth);
475        self
476    }
477
478    /// Override one or more of the OWASP security headers emitted on
479    /// every response. See [`SecurityHeadersConfig`] for the three-state
480    /// semantic (`None` = default, `Some("")` = omit, `Some(v)` =
481    /// override). Values are validated by [`Self::validate`].
482    #[must_use]
483    pub fn with_security_headers(mut self, headers: SecurityHeadersConfig) -> Self {
484        self.security_headers = headers;
485        self
486    }
487
488    /// Override the bind address (e.g. `127.0.0.1:8080`). Useful when the
489    /// final port is only known after pre-binding an ephemeral listener
490    /// (tests, dynamic-port deployments).
491    #[must_use]
492    pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
493        self.bind_addr = addr.into();
494        self
495    }
496
497    /// Attach an RBAC policy. Tool calls are checked against the policy
498    /// after authentication.
499    #[must_use]
500    pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
501        self.rbac = Some(rbac);
502        self
503    }
504
505    /// Configure TLS by providing the certificate and private key paths
506    /// (PEM). Both must be readable at startup. Without this call, the
507    /// server runs plain HTTP.
508    #[must_use]
509    pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
510        self.tls_cert_path = Some(cert_path.into());
511        self.tls_key_path = Some(key_path.into());
512        self
513    }
514
515    /// Set the externally reachable base URL (e.g. `https://mcp.example.com`).
516    /// Required when binding `0.0.0.0` behind a reverse proxy or inside
517    /// a container so OAuth metadata and auto-derived origins resolve correctly.
518    #[must_use]
519    pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
520        self.public_url = Some(url.into());
521        self
522    }
523
524    /// Replace the allowed Origin allow-list (DNS-rebinding protection).
525    /// When empty and [`with_public_url`](Self::with_public_url) is set,
526    /// the origin is auto-derived.
527    #[must_use]
528    pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
529    where
530        I: IntoIterator<Item = S>,
531        S: Into<String>,
532    {
533        self.allowed_origins = origins.into_iter().map(Into::into).collect();
534        self
535    }
536
537    /// Merge an additional axum router at the top level. Routes added
538    /// here **bypass** rmcp-server-kit auth and RBAC; the application is responsible
539    /// for its own protection.
540    #[must_use]
541    pub fn with_extra_router(mut self, router: axum::Router) -> Self {
542        self.extra_router = Some(router);
543        self
544    }
545
546    /// Install an async readiness probe for `/readyz`. Without this call,
547    /// `/readyz` mirrors `/healthz` (always 200 OK).
548    #[must_use]
549    pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
550        self.readiness_check = Some(check);
551        self
552    }
553
554    /// Override the maximum request body (bytes). Must be `> 0`.
555    /// Default: 1 MiB.
556    #[must_use]
557    pub fn with_max_request_body(mut self, bytes: usize) -> Self {
558        self.max_request_body = bytes;
559        self
560    }
561
562    /// Override the per-request processing timeout. Default: 2 minutes.
563    #[must_use]
564    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
565        self.request_timeout = timeout;
566        self
567    }
568
569    /// Override the graceful shutdown grace period. Default: 30 seconds.
570    #[must_use]
571    pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
572        self.shutdown_timeout = timeout;
573        self
574    }
575
576    /// Override the MCP session idle timeout. Default: 20 minutes.
577    #[must_use]
578    pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
579        self.session_idle_timeout = timeout;
580        self
581    }
582
583    /// Override the SSE keep-alive interval. Default: 15 seconds.
584    #[must_use]
585    pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
586        self.sse_keep_alive = interval;
587        self
588    }
589
590    /// Cap the global number of in-flight HTTP requests via
591    /// `tower::load_shed`. Excess requests receive 503 Service Unavailable.
592    /// Default: unlimited.
593    #[must_use]
594    pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
595        self.max_concurrent_requests = Some(limit);
596        self
597    }
598
599    /// Cap tool invocations per source IP per minute. Enforced on every
600    /// `tools/call` request.
601    #[must_use]
602    pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
603        self.tool_rate_limit = Some(per_minute);
604        self
605    }
606
607    /// Register a callback that receives the [`ReloadHandle`] after the
608    /// server is built. Use it to wire SIGHUP-style hot reloads of API
609    /// keys and RBAC policy.
610    #[must_use]
611    pub fn with_reload_callback<F>(mut self, callback: F) -> Self
612    where
613        F: FnOnce(ReloadHandle) + Send + 'static,
614    {
615        self.on_reload_ready = Some(Box::new(callback));
616        self
617    }
618
619    /// Enable gzip/brotli response compression on MCP responses.
620    /// `min_size` is the smallest body size (bytes) eligible for
621    /// compression. Default min size: 1024.
622    #[must_use]
623    pub fn enable_compression(mut self, min_size: u16) -> Self {
624        self.compression_enabled = true;
625        self.compression_min_size = min_size;
626        self
627    }
628
629    /// Enable `/admin/*` diagnostic endpoints. Requires
630    /// [`with_auth`](Self::with_auth) to be set and enabled; otherwise
631    /// [`validate`](Self::validate) returns an error. `role` is the RBAC
632    /// role gate (default: `"admin"`).
633    #[must_use]
634    pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
635        self.admin_enabled = true;
636        self.admin_role = role.into();
637        self
638    }
639
640    /// Log inbound HTTP request headers at DEBUG level. Sensitive
641    /// values remain redacted by the logging layer.
642    #[must_use]
643    pub fn enable_request_header_logging(mut self) -> Self {
644        self.log_request_headers = true;
645        self
646    }
647
648    /// Enable the Prometheus metrics listener on `bind` (e.g.
649    /// `127.0.0.1:9090`). Requires the `metrics` crate feature.
650    #[cfg(feature = "metrics")]
651    #[must_use]
652    pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
653        self.metrics_enabled = true;
654        self.metrics_bind = bind.into();
655        self
656    }
657
658    /// Validate the configuration and consume `self`, returning a
659    /// [`Validated<McpServerConfig>`] proof token required by [`serve`]
660    /// and [`serve_with_listener`]. This is the only way to construct
661    /// `Validated<McpServerConfig>`, so the type system guarantees
662    /// validation has run before the server starts.
663    ///
664    /// Checks:
665    ///
666    /// 1. `admin_enabled` requires `auth` to be configured and enabled.
667    /// 2. `tls_cert_path` and `tls_key_path` must both be set or both
668    ///    be unset.
669    /// 3. `bind_addr` must parse as a [`SocketAddr`].
670    /// 4. `public_url`, when set, must start with `http://` or `https://`.
671    /// 5. Each entry in `allowed_origins` must start with `http://` or
672    ///    `https://`.
673    /// 6. `max_request_body` must be greater than zero.
674    /// 7. When the `oauth` feature is enabled and an [`OAuthConfig`] is
675    ///    present, all OAuth URL fields (`jwks_uri`, `proxy.authorize_url`,
676    ///    `proxy.token_url`, `proxy.introspection_url`,
677    ///    `proxy.revocation_url`, `token_exchange.token_url`) must parse
678    ///    and use the `https` scheme. Set
679    ///    [`OAuthConfig::allow_http_oauth_urls`] to permit `http://`
680    ///    targets (strongly discouraged in production - see the field-level
681    ///    docs for the threat model).
682    ///
683    /// [`OAuthConfig`]: crate::oauth::OAuthConfig
684    /// [`OAuthConfig::allow_http_oauth_urls`]: crate::oauth::OAuthConfig::allow_http_oauth_urls
685    ///
686    /// # Errors
687    ///
688    /// Returns [`McpxError::Config`] with a human-readable message on
689    /// the first validation failure.
690    pub fn validate(self) -> Result<Validated<Self>, McpxError> {
691        self.check()?;
692        Ok(Validated(self))
693    }
694
695    /// Run the validation checks without consuming `self`. Used by
696    /// internal call sites (e.g. tests) that need to inspect a config
697    /// without taking ownership.
698    fn check(&self) -> Result<(), McpxError> {
699        // 1. admin <-> auth dependency. Mirrors the runtime check in
700        //    `build_app_router`: admin endpoints require an auth state,
701        //    which is built only when `auth` is `Some` *and* `enabled`.
702        if self.admin_enabled {
703            let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
704            if !auth_enabled {
705                return Err(McpxError::Config(
706                    "admin_enabled=true requires auth to be configured and enabled".into(),
707                ));
708            }
709        }
710
711        // 2. TLS cert / key must be paired
712        match (&self.tls_cert_path, &self.tls_key_path) {
713            (Some(_), None) => {
714                return Err(McpxError::Config(
715                    "tls_cert_path is set but tls_key_path is missing".into(),
716                ));
717            }
718            (None, Some(_)) => {
719                return Err(McpxError::Config(
720                    "tls_key_path is set but tls_cert_path is missing".into(),
721                ));
722            }
723            _ => {}
724        }
725
726        // 3. bind_addr parses
727        if self.bind_addr.parse::<SocketAddr>().is_err() {
728            return Err(McpxError::Config(format!(
729                "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
730                self.bind_addr
731            )));
732        }
733
734        // 4. public_url scheme
735        if let Some(ref url) = self.public_url
736            && !(url.starts_with("http://") || url.starts_with("https://"))
737        {
738            return Err(McpxError::Config(format!(
739                "public_url {url:?} must start with http:// or https://"
740            )));
741        }
742
743        // 5. allowed_origins scheme
744        for origin in &self.allowed_origins {
745            if !(origin.starts_with("http://") || origin.starts_with("https://")) {
746                return Err(McpxError::Config(format!(
747                    "allowed_origins entry {origin:?} must start with http:// or https://"
748                )));
749            }
750        }
751
752        // 6. max_request_body > 0
753        if self.max_request_body == 0 {
754            return Err(McpxError::Config(
755                "max_request_body must be greater than zero".into(),
756            ));
757        }
758
759        // 7. OAuth URL fields enforce HTTPS (unless `allow_http_oauth_urls`)
760        #[cfg(feature = "oauth")]
761        if let Some(auth_cfg) = &self.auth
762            && let Some(oauth_cfg) = &auth_cfg.oauth
763        {
764            oauth_cfg.validate()?;
765        }
766
767        // 8. Security-header overrides parse as valid HTTP header values,
768        //    and HSTS does not smuggle in a `preload` directive.
769        validate_security_headers(&self.security_headers)?;
770
771        // 9. max_concurrent_requests must be > 0 when set. Zero would
772        //    deadlock the global concurrency limiter and reject every
773        //    request. Mirrors the TOML-side check in `src/config.rs`.
774        if let Some(0) = self.max_concurrent_requests {
775            return Err(McpxError::Config(
776                "max_concurrent_requests must be greater than zero when set".into(),
777            ));
778        }
779
780        // 10. Auth rate-limit `max_tracked_keys` must be > 0. A zero cap
781        //     would force `BoundedKeyedLimiter` to evict on every insert
782        //     and effectively disable rate limiting.
783        if let Some(auth_cfg) = &self.auth
784            && let Some(rl) = &auth_cfg.rate_limit
785            && rl.max_tracked_keys == 0
786        {
787            return Err(McpxError::Config(
788                "auth.rate_limit.max_tracked_keys must be greater than zero".into(),
789            ));
790        }
791
792        Ok(())
793    }
794}
795
796/// Handle for hot-reloading server configuration without restart.
797///
798/// Obtained via [`McpServerConfig::on_reload_ready`].
799/// All swap operations are lock-free and wait-free -- in-flight requests
800/// finish with the old values while new requests see the update immediately.
801#[allow(
802    missing_debug_implementations,
803    reason = "contains Arc<AuthState> with non-Debug fields"
804)]
805pub struct ReloadHandle {
806    auth: Option<Arc<AuthState>>,
807    rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
808    crl_set: Option<Arc<CrlSet>>,
809}
810
811impl ReloadHandle {
812    /// Atomically replace the API key list used by the auth middleware.
813    pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
814        if let Some(ref auth) = self.auth {
815            auth.reload_keys(keys);
816        }
817    }
818
819    /// Atomically replace the RBAC policy used by the RBAC middleware.
820    pub fn reload_rbac(&self, policy: RbacPolicy) {
821        if let Some(ref rbac) = self.rbac {
822            rbac.store(Arc::new(policy));
823            tracing::info!("RBAC policy reloaded");
824        }
825    }
826
827    /// Force an immediate refresh of all cached mTLS CRLs.
828    ///
829    /// # Errors
830    ///
831    /// Returns an error if CRL refresh is unavailable or verifier rebuild fails.
832    pub async fn refresh_crls(&self) -> Result<(), McpxError> {
833        let Some(ref crl_set) = self.crl_set else {
834            return Err(McpxError::Config(
835                "CRL refresh requested but mTLS CRL support is not configured".into(),
836            ));
837        };
838
839        crl_set.force_refresh().await
840    }
841}
842
843/// Generic MCP HTTP server.
844///
845/// Wraps an axum server with `/healthz` and `/mcp` endpoints.
846/// When `tls_cert_path` and `tls_key_path` are both set, the server binds
847/// with TLS (rustls). Optionally supports mTLS client certificate auth.
848///
849/// # Errors
850///
851/// Returns an error if the TCP listener cannot bind, TLS config is invalid,
852/// or the server fails.
853// NOTE: cognitive complexity reduced from 111/25 to 83/25 by
854// extracting `run_server` (serve-loop tail) and `install_oauth_proxy_routes`.
855// Remaining flow is a linear router builder: middleware layering, feature-
856// gated auth/RBAC wiring, and PRM/metrics installation. Further extraction
857// would require threading many `&mut Router` helpers and hurt readability
858// of the layer order (which is security-relevant and must stay visible).
859#[allow(
860    clippy::too_many_lines,
861    clippy::cognitive_complexity,
862    reason = "middleware layer order is security-critical and must remain visible at one glance; extracting `&mut Router` helpers would obscure the auth/RBAC/origin/rate-limit ordering"
863)]
864/// Internal bundle of values produced by [`build_app_router`] and
865/// consumed by [`serve`] / [`serve_with_listener`] when driving the
866/// HTTP listener.
867struct AppRunParams {
868    /// TLS cert/key paths when TLS is configured.
869    tls_paths: Option<(PathBuf, PathBuf)>,
870    /// mTLS configuration when mutual-TLS auth is enabled.
871    mtls_config: Option<MtlsConfig>,
872    /// Graceful shutdown drain window.
873    shutdown_timeout: Duration,
874    /// Shared auth state used by hot-reload callbacks.
875    auth_state: Option<Arc<AuthState>>,
876    /// Hot-reloadable RBAC state used by reload callbacks.
877    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
878    /// Optional callback that receives the final [`ReloadHandle`].
879    on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
880    /// Server-internal cancellation token. Cancelled by [`run_server`]
881    /// once the shutdown trigger fires (so rmcp's child token also
882    /// fires, terminating in-flight MCP sessions).
883    ct: CancellationToken,
884    /// `"http"` or `"https"` -- used only for boot-time logging.
885    scheme: &'static str,
886    /// Server name -- used only for boot-time logging.
887    name: String,
888}
889
890/// Build the full application axum [`axum::Router`] (MCP route +
891/// middleware stack + admin + OAuth + health endpoints + security
892/// headers + CORS + compression + concurrency limit + origin check)
893/// and the [`AppRunParams`] needed to drive it.
894///
895/// This is the shared core of [`serve`] and [`serve_with_listener`].
896/// It performs *no* network I/O: callers are responsible for binding
897/// (or accepting a pre-bound) [`TcpListener`] and invoking
898/// [`run_server`].
899#[allow(
900    clippy::cognitive_complexity,
901    reason = "router assembly is intrinsically sequential; splitting harms readability"
902)]
903#[allow(
904    deprecated,
905    reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
906)]
907fn build_app_router<H, F>(
908    mut config: McpServerConfig,
909    handler_factory: F,
910) -> anyhow::Result<(axum::Router, AppRunParams)>
911where
912    H: ServerHandler + 'static,
913    F: Fn() -> H + Send + Sync + Clone + 'static,
914{
915    let ct = CancellationToken::new();
916
917    let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
918    tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");
919
920    let mcp_service = StreamableHttpService::new(
921        move || Ok(handler_factory()),
922        {
923            let mut mgr = LocalSessionManager::default();
924            mgr.session_config.keep_alive = Some(config.session_idle_timeout);
925            mgr.into()
926        },
927        StreamableHttpServerConfig::default()
928            .with_allowed_hosts(allowed_hosts)
929            .with_sse_keep_alive(Some(config.sse_keep_alive))
930            .with_cancellation_token(ct.child_token()),
931    );
932
933    // Build the MCP route, optionally wrapped with auth and RBAC middleware.
934    let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);
935
936    // Build auth state eagerly when auth is configured so we can wire both
937    // the auth middleware *and* the optional admin router against the same
938    // state. The middleware itself is installed further down in layer order.
939    let auth_state: Option<Arc<AuthState>> = match config.auth {
940        Some(ref auth_config) if auth_config.enabled => {
941            let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
942            let pre_auth_limiter = auth_config
943                .rate_limit
944                .as_ref()
945                .map(crate::auth::build_pre_auth_limiter);
946
947            #[cfg(feature = "oauth")]
948            let jwks_cache = auth_config
949                .oauth
950                .as_ref()
951                .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
952                .transpose()
953                .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;
954
955            Some(Arc::new(AuthState {
956                api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
957                rate_limiter,
958                pre_auth_limiter,
959                #[cfg(feature = "oauth")]
960                jwks_cache,
961                seen_identities: crate::auth::SeenIdentitySet::new(),
962                counters: crate::auth::AuthCounters::default(),
963            }))
964        }
965        _ => None,
966    };
967
968    // Build the RBAC policy swap early so the admin router and the later
969    // RBAC middleware layer share the same hot-reloadable state.
970    let rbac_swap = Arc::new(ArcSwap::new(
971        config
972            .rbac
973            .clone()
974            .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
975    ));
976
977    // Optional /admin/* diagnostic routes. Merged BEFORE the
978    // body-limit/timeout/RBAC/origin/auth layers so all of them apply.
979    if config.admin_enabled {
980        let Some(ref auth_state_ref) = auth_state else {
981            return Err(anyhow::anyhow!(
982                "admin_enabled=true requires auth to be configured and enabled"
983            ));
984        };
985        let admin_state = crate::admin::AdminState {
986            started_at: std::time::Instant::now(),
987            name: config.name.clone(),
988            version: config.version.clone(),
989            auth: Some(Arc::clone(auth_state_ref)),
990            rbac: Arc::clone(&rbac_swap),
991        };
992        let admin_cfg = crate::admin::AdminConfig {
993            role: config.admin_role.clone(),
994        };
995        mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
996        tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
997    }
998
999    // ----- Middleware order (CRITICAL: read carefully) ------------------
1000    //
1001    // axum/tower applies layers **bottom-up** at runtime: the LAST layer
1002    // added is the OUTERMOST (runs first on a request). To achieve a
1003    // request-time flow of:
1004    //
1005    //   body-limit -> timeout -> auth -> rbac -> handler
1006    //
1007    // we add layers in the REVERSE order:
1008    //
1009    //   1. RBAC               (innermost, runs last before handler)
1010    //   2. auth               (parses identity, sets extension for RBAC)
1011    //   3. timeout            (bounds total request time)
1012    //   4. body-limit         (outermost on /mcp; caps payload before
1013    //                          anything else reads/buffers it)
1014    //
1015    // Origin validation is installed on the OUTER router (after the
1016    // /mcp router is merged in), so it also protects /healthz, /readyz,
1017    // /version, and any OAuth proxy endpoints.
1018    //
1019    // Rationale:
1020    // - Body-limit must be outermost on /mcp so RBAC (which reads the
1021    //   JSON-RPC body) cannot be DoS'd by a 100MB payload.
1022    // - Auth must run before RBAC because RBAC consumes
1023    //   `req.extensions().get::<AuthIdentity>()` to enforce per-role
1024    //   policy.
1025    // - Origin runs before auth so we reject cross-origin requests
1026    //   without spending Argon2 cycles on unauthenticated callers.
1027
1028    // [1] RBAC + tool rate-limit layer (innermost; closest to handler).
1029    // Always installed: even when RBAC is disabled, tool rate limiting may
1030    // be active (MCP spec: servers MUST rate limit tool invocations).
1031    {
1032        let tool_limiter: Option<Arc<ToolRateLimiter>> =
1033            config.tool_rate_limit.map(build_tool_rate_limiter);
1034
1035        if rbac_swap.load().is_enabled() {
1036            tracing::info!("RBAC enforcement enabled on /mcp");
1037        }
1038        if let Some(limit) = config.tool_rate_limit {
1039            tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
1040        }
1041
1042        let rbac_for_mw = Arc::clone(&rbac_swap);
1043        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1044            let p = rbac_for_mw.load_full();
1045            let tl = tool_limiter.clone();
1046            rbac_middleware(p, tl, req, next)
1047        }));
1048    }
1049
1050    // [2] Auth layer (runs before RBAC so AuthIdentity is in extensions).
1051    if let Some(ref auth_config) = config.auth
1052        && auth_config.enabled
1053    {
1054        let Some(ref state) = auth_state else {
1055            return Err(anyhow::anyhow!("auth state missing despite enabled config"));
1056        };
1057
1058        let methods: Vec<&str> = [
1059            auth_config.mtls.is_some().then_some("mTLS"),
1060            (!auth_config.api_keys.is_empty()).then_some("bearer"),
1061            #[cfg(feature = "oauth")]
1062            auth_config.oauth.is_some().then_some("oauth-jwt"),
1063        ]
1064        .into_iter()
1065        .flatten()
1066        .collect();
1067
1068        tracing::info!(
1069            methods = %methods.join(", "),
1070            api_keys = auth_config.api_keys.len(),
1071            "auth enabled on /mcp"
1072        );
1073
1074        let state_for_mw = Arc::clone(state);
1075        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1076            let s = Arc::clone(&state_for_mw);
1077            auth_middleware(s, req, next)
1078        }));
1079    }
1080
1081    // [3] Request timeout (returns 408 on expiry). Bounds total request
1082    // duration including auth + handler.
1083    mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
1084        axum::http::StatusCode::REQUEST_TIMEOUT,
1085        config.request_timeout,
1086    ));
1087
1088    // [4] Request body size limit (OUTERMOST on /mcp). Prevents OOM /
1089    // DoS from oversized payloads BEFORE any inner layer (auth, RBAC)
1090    // attempts to buffer or parse the body.
1091    mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
1092        config.max_request_body,
1093    ));
1094
1095    // Compute the effective allowed-origins list for the outer
1096    // origin-check layer (installed on the merged router below). When
1097    // `allowed_origins` is empty but `public_url` is set, auto-derive
1098    // the origin from the public URL so MCP clients (e.g. Claude Code)
1099    // that send `Origin: <server-url>` are accepted without explicit
1100    // config.
1101    let mut effective_origins = config.allowed_origins.clone();
1102    if effective_origins.is_empty()
1103        && let Some(ref url) = config.public_url
1104    {
1105        // Origin = scheme + "://" + host (+ ":" + port if non-default).
1106        // Strip any path/query from the public URL. Offsets come from
1107        // `find`, so they are char-boundary-aligned; `get(..)` keeps that
1108        // machine-checked (a violation degrades to an empty slice).
1109        if let Some(scheme_end) = url.find("://") {
1110            let scheme_with_sep = url.get(..scheme_end + 3).unwrap_or_default();
1111            let after_scheme = url.get(scheme_end + 3..).unwrap_or_default();
1112            let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
1113            let host = after_scheme.get(..host_end).unwrap_or_default();
1114            let origin = format!("{scheme_with_sep}{host}");
1115            tracing::info!(
1116                %origin,
1117                "auto-derived allowed origin from public_url"
1118            );
1119            effective_origins.push(origin);
1120        }
1121    }
1122    let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
1123    let cors_origins = Arc::clone(&allowed_origins);
1124    let log_request_headers = config.log_request_headers;
1125
1126    let readyz_route = if let Some(check) = config.readiness_check.take() {
1127        axum::routing::get(move || readyz(Arc::clone(&check)))
1128    } else {
1129        axum::routing::get(healthz)
1130    };
1131
1132    #[allow(unused_mut)] // mut needed when oauth feature adds PRM route
1133    let mut router = axum::Router::new()
1134        .route("/healthz", axum::routing::get(healthz))
1135        .route("/readyz", readyz_route)
1136        .route(
1137            "/version",
1138            axum::routing::get({
1139                // Pre-serialize the version payload once at router-build
1140                // time. The handler then serves a cheap `Arc::clone` of the
1141                // immutable bytes per request, avoiding `serde_json::Value`
1142                // allocation + serialization on every `/version` hit.
1143                let payload_bytes: Arc<[u8]> =
1144                    serialize_version_payload(&config.name, &config.version);
1145                move || {
1146                    let p = Arc::clone(&payload_bytes);
1147                    async move {
1148                        (
1149                            [(axum::http::header::CONTENT_TYPE, "application/json")],
1150                            p.to_vec(),
1151                        )
1152                    }
1153                }
1154            }),
1155        )
1156        .merge(mcp_router);
1157
1158    // Merge application-specific routes (bypass MCP auth/RBAC middleware).
1159    if let Some(extra) = config.extra_router.take() {
1160        router = router.merge(extra);
1161    }
1162
1163    // RFC 9728: Protected Resource Metadata endpoint.
1164    // When OAuth is configured, serve full metadata with authorization_servers.
1165    // Otherwise, serve a minimal document with just the resource URL and no
1166    // authorization_servers -- this tells MCP clients (e.g. Claude Code SDK)
1167    // that the server exists but does NOT require OAuth authentication,
1168    // preventing them from gating the connection behind a broken auth flow.
1169    let server_url = if let Some(ref url) = config.public_url {
1170        url.trim_end_matches('/').to_owned()
1171    } else {
1172        let prm_scheme = if config.tls_cert_path.is_some() {
1173            "https"
1174        } else {
1175            "http"
1176        };
1177        format!("{prm_scheme}://{}", config.bind_addr)
1178    };
1179    let resource_url = format!("{server_url}/mcp");
1180
1181    #[cfg(feature = "oauth")]
1182    let prm_metadata = if let Some(ref auth_config) = config.auth
1183        && let Some(ref oauth_config) = auth_config.oauth
1184    {
1185        crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
1186    } else {
1187        serde_json::json!({ "resource": resource_url })
1188    };
1189    #[cfg(not(feature = "oauth"))]
1190    let prm_metadata = serde_json::json!({ "resource": resource_url });
1191
1192    router = router.route(
1193        "/.well-known/oauth-protected-resource",
1194        axum::routing::get(move || {
1195            let m = prm_metadata.clone();
1196            async move { axum::Json(m) }
1197        }),
1198    );
1199
1200    // OAuth 2.1 proxy endpoints: when an OAuth proxy is configured, expose
1201    // /authorize, /token, /register, and authorization server metadata so
1202    // MCP clients can perform Authorization Code + PKCE against the upstream
1203    // IdP (e.g. Keycloak) transparently.
1204    #[cfg(feature = "oauth")]
1205    if let Some(ref auth_config) = config.auth
1206        && let Some(ref oauth_config) = auth_config.oauth
1207        && oauth_config.proxy.is_some()
1208    {
1209        router =
1210            install_oauth_proxy_routes(router, &server_url, oauth_config, auth_state.as_ref())?;
1211    }
1212
1213    // OWASP security response headers (applied to all responses).
1214    // HSTS is conditional on TLS being configured.
1215    let is_tls = config.tls_cert_path.is_some();
1216    let security_headers_cfg = Arc::new(config.security_headers.clone());
1217    router = router.layer(axum::middleware::from_fn(move |req, next| {
1218        let cfg = Arc::clone(&security_headers_cfg);
1219        security_headers_middleware(is_tls, cfg, req, next)
1220    }));
1221
1222    // CORS preflight layer (required for browser-based MCP clients).
1223    // Uses the same effective origins as the origin check middleware
1224    // (including auto-derived origin from public_url).
1225    if !cors_origins.is_empty() {
1226        let cors = tower_http::cors::CorsLayer::new()
1227            .allow_origin(
1228                cors_origins
1229                    .iter()
1230                    .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
1231                    .collect::<Vec<_>>(),
1232            )
1233            .allow_methods([
1234                axum::http::Method::GET,
1235                axum::http::Method::POST,
1236                axum::http::Method::OPTIONS,
1237            ])
1238            .allow_headers([
1239                axum::http::header::CONTENT_TYPE,
1240                axum::http::header::AUTHORIZATION,
1241            ]);
1242        router = router.layer(cors);
1243    }
1244
1245    // Optional response compression (gzip + brotli). Skips small bodies
1246    // to avoid overhead. Applied after CORS so preflight responses remain
1247    // uncompressed.
1248    if config.compression_enabled {
1249        use tower_http::compression::Predicate as _;
1250        let predicate = tower_http::compression::DefaultPredicate::new().and(
1251            tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
1252        );
1253        router = router.layer(
1254            tower_http::compression::CompressionLayer::new()
1255                .gzip(true)
1256                .br(true)
1257                .compress_when(predicate),
1258        );
1259        tracing::info!(
1260            min_size = config.compression_min_size,
1261            "response compression enabled (gzip, br)"
1262        );
1263    }
1264
1265    // Optional global concurrency cap. `load_shed` converts the
1266    // `ConcurrencyLimit` back-pressure error into 503 instead of hanging.
1267    if let Some(max) = config.max_concurrent_requests {
1268        let overload_handler = tower::ServiceBuilder::new()
1269            .layer(axum::error_handling::HandleErrorLayer::new(
1270                |_err: tower::BoxError| async {
1271                    (
1272                        axum::http::StatusCode::SERVICE_UNAVAILABLE,
1273                        axum::Json(serde_json::json!({
1274                            "error": "overloaded",
1275                            "error_description": "server is at capacity, retry later"
1276                        })),
1277                    )
1278                },
1279            ))
1280            .layer(tower::load_shed::LoadShedLayer::new())
1281            .layer(tower::limit::ConcurrencyLimitLayer::new(max));
1282        router = router.layer(overload_handler);
1283        tracing::info!(max, "global concurrency limit enabled");
1284    }
1285
1286    // JSON fallback for unmatched routes. Without this, axum returns
1287    // an empty-body 404 that breaks MCP clients (e.g. Claude Code SDK)
1288    // when they probe OAuth endpoints like /authorize or /token.
1289    router = router.fallback(|| async {
1290        (
1291            axum::http::StatusCode::NOT_FOUND,
1292            axum::Json(serde_json::json!({
1293                "error": "not_found",
1294                "error_description": "The requested endpoint does not exist"
1295            })),
1296        )
1297    });
1298
1299    // Prometheus metrics: recording middleware + separate listener.
1300    #[cfg(feature = "metrics")]
1301    if config.metrics_enabled {
1302        let metrics = Arc::new(
1303            crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
1304        );
1305        let m = Arc::clone(&metrics);
1306        router = router.layer(axum::middleware::from_fn(
1307            move |req: Request<Body>, next: Next| {
1308                let m = Arc::clone(&m);
1309                metrics_middleware(m, req, next)
1310            },
1311        ));
1312        let metrics_bind = config.metrics_bind.clone();
1313        let metrics_shutdown = ct.clone();
1314        tokio::spawn(async move {
1315            if let Err(e) =
1316                crate::metrics::serve_metrics(metrics_bind, metrics, metrics_shutdown).await
1317            {
1318                tracing::error!("metrics listener failed: {e}");
1319            }
1320        });
1321    }
1322
1323    // Origin validation layer (MCP spec: servers MUST validate the
1324    // Origin header to prevent DNS rebinding attacks). Installed as the
1325    // OUTERMOST layer on the OUTER router so it protects ALL routes
1326    // (`/mcp`, `/healthz`, `/readyz`, `/version`, OAuth proxy endpoints,
1327    // admin endpoints, extra_router, etc.) and runs BEFORE auth so we
1328    // reject cross-origin attackers without spending Argon2 cycles.
1329    //
1330    // Origin-less requests (e.g. server-to-server probes, curl, native
1331    // MCP clients) are permitted; only requests with an Origin header
1332    // that does not match `effective_origins` are rejected.
1333    router = router.layer(axum::middleware::from_fn(move |req, next| {
1334        let origins = Arc::clone(&allowed_origins);
1335        origin_check_middleware(origins, log_request_headers, req, next)
1336    }));
1337
1338    let scheme = if config.tls_cert_path.is_some() {
1339        "https"
1340    } else {
1341        "http"
1342    };
1343
1344    let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
1345        (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
1346        _ => None,
1347    };
1348    let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();
1349
1350    Ok((
1351        router,
1352        AppRunParams {
1353            tls_paths,
1354            mtls_config,
1355            shutdown_timeout: config.shutdown_timeout,
1356            auth_state,
1357            rbac_swap,
1358            on_reload_ready: config.on_reload_ready.take(),
1359            ct,
1360            scheme,
1361            name: config.name.clone(),
1362        },
1363    ))
1364}
1365
1366/// Run the MCP HTTP server, binding to `config.bind_addr` and serving
1367/// until an OS shutdown signal (Ctrl-C / SIGTERM) is received.
1368///
1369/// This is the standard entry point for production deployments. For
1370/// deterministic shutdown control (e.g. integration tests), see
1371/// [`serve_with_listener`].
1372///
1373/// The configuration must be validated first via
1374/// [`McpServerConfig::validate`], which returns a [`Validated`] proof
1375/// token. This typestate guarantees, at compile time, that the server
1376/// never starts with an invalid configuration.
1377///
1378/// # Errors
1379///
1380/// Returns [`McpxError::Startup`] if binding to `config.bind_addr`
1381/// fails, or if the underlying axum server returns an error.
1382pub async fn serve<H, F>(
1383    config: Validated<McpServerConfig>,
1384    handler_factory: F,
1385) -> Result<(), McpxError>
1386where
1387    H: ServerHandler + 'static,
1388    F: Fn() -> H + Send + Sync + Clone + 'static,
1389{
1390    let config = config.into_inner();
1391    #[allow(
1392        deprecated,
1393        reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1394    )]
1395    let bind_addr = config.bind_addr.clone();
1396    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1397
1398    let listener = TcpListener::bind(&bind_addr)
1399        .await
1400        .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1401    log_listening(&params.name, params.scheme, &bind_addr);
1402
1403    run_server(
1404        router,
1405        listener,
1406        params.tls_paths,
1407        params.mtls_config,
1408        params.shutdown_timeout,
1409        params.auth_state,
1410        params.rbac_swap,
1411        params.on_reload_ready,
1412        params.ct,
1413    )
1414    .await
1415    .map_err(anyhow_to_startup)
1416}
1417
1418/// Run the MCP HTTP server on a pre-bound [`TcpListener`], with optional
1419/// readiness signalling and external shutdown control.
1420///
1421/// This variant is intended for **deterministic integration tests** and
1422/// for embedders that need to bind the listening socket themselves
1423/// (e.g. systemd socket activation). Compared to [`serve`]:
1424///
1425/// * The caller passes a `TcpListener` that is already bound. This
1426///   eliminates the bind race in tests that previously required
1427///   poll-the-`/healthz`-loop start-up detection.
1428/// * `ready_tx`, when `Some`, receives the socket's
1429///   [`SocketAddr`] *after* the router is built and immediately before
1430///   the server starts accepting connections. Tests can `await` the
1431///   matching `oneshot::Receiver` to know exactly when it is safe to
1432///   issue requests.
1433/// * `shutdown`, when `Some`, gives the caller a
1434///   [`CancellationToken`] that triggers the same graceful-shutdown
1435///   path as a real OS signal. This avoids cross-platform issues with
1436///   sending real `SIGTERM` from tests on Windows.
1437///
1438/// All three optional parameters degrade gracefully: if `ready_tx` is
1439/// `None`, no signal is sent; if `shutdown` is `None`, the server only
1440/// stops on an OS signal (just like [`serve`]).
1441///
1442/// # Errors
1443///
1444/// Returns [`McpxError::Startup`] if router construction fails, if reading
1445/// the listener's `local_addr()` fails, or if the underlying axum
1446/// server returns an error.
1447pub async fn serve_with_listener<H, F>(
1448    listener: TcpListener,
1449    config: Validated<McpServerConfig>,
1450    handler_factory: F,
1451    ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
1452    shutdown: Option<CancellationToken>,
1453) -> Result<(), McpxError>
1454where
1455    H: ServerHandler + 'static,
1456    F: Fn() -> H + Send + Sync + Clone + 'static,
1457{
1458    let config = config.into_inner();
1459    let local_addr = listener
1460        .local_addr()
1461        .map_err(|e| io_to_startup("listener.local_addr", e))?;
1462    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1463
1464    log_listening(&params.name, params.scheme, &local_addr.to_string());
1465
1466    // Forward external shutdown into the server-internal cancellation
1467    // token so `run_server`'s shutdown trigger picks it up alongside
1468    // any real OS signal.
1469    if let Some(external) = shutdown {
1470        let internal = params.ct.clone();
1471        tokio::spawn(async move {
1472            external.cancelled().await;
1473            internal.cancel();
1474        });
1475    }
1476
1477    // Signal readiness *after* the router is fully built and external
1478    // shutdown is wired, but *before* run_server takes ownership of
1479    // the listener. The receiver can immediately issue requests.
1480    if let Some(tx) = ready_tx {
1481        // Receiver may have been dropped (test gave up). That's fine.
1482        let _ = tx.send(local_addr);
1483    }
1484
1485    run_server(
1486        router,
1487        listener,
1488        params.tls_paths,
1489        params.mtls_config,
1490        params.shutdown_timeout,
1491        params.auth_state,
1492        params.rbac_swap,
1493        params.on_reload_ready,
1494        params.ct,
1495    )
1496    .await
1497    .map_err(anyhow_to_startup)
1498}
1499
/// Emit the standard "listening on …" log lines used by both
/// [`serve`] and [`serve_with_listener`].
///
/// `name` is the server name, `scheme` the URL scheme to advertise
/// (e.g. `http` / `https`), and `addr` the `host:port` string the
/// endpoint URLs are built from. One INFO line is emitted for the
/// listener itself, followed by one per well-known route (`/mcp`,
/// `/healthz`, `/readyz`); the route labels are padded so the URLs line up.
#[allow(
    clippy::cognitive_complexity,
    reason = "tracing::info! macro expansions inflate the score; logic is trivial"
)]
fn log_listening(name: &str, scheme: &str, addr: &str) {
    tracing::info!("{name} listening on {addr}");
    tracing::info!("  MCP endpoint: {scheme}://{addr}/mcp");
    tracing::info!("  Health check: {scheme}://{addr}/healthz");
    tracing::info!("  Readiness:   {scheme}://{addr}/readyz");
}
1512
1513/// Drive the chosen axum server variant (TLS or plain) with a graceful
1514/// shutdown window. Consumes the router and listener.
1515///
1516/// # Shutdown semantics
1517///
1518/// A single shutdown trigger (the FIRST of: OS signal via
1519/// `shutdown_signal()`, or external cancellation of `ct`) starts BOTH:
1520///
1521/// 1. axum's `.with_graceful_shutdown(...)` future, which stops
1522///    accepting new connections and waits for in-flight requests to
1523///    drain;
1524/// 2. a `tokio::time::sleep(shutdown_timeout)` race that forces exit if
1525///    drainage exceeds `shutdown_timeout`.
1526///
1527/// Previously this function awaited `shutdown_signal()` independently
1528/// in BOTH branches of a `tokio::select!`. Because `shutdown_signal`
1529/// resolves once per future and consumes one signal, the force-exit
1530/// timer was tied to a SECOND signal (a second SIGTERM the operator
1531/// would never send). Under a single SIGTERM the graceful drain could
1532/// hang indefinitely. The current implementation derives both branches
1533/// from a single shared trigger so the timeout race is anchored to the
1534/// FIRST (and only) signal.
1535#[allow(
1536    clippy::too_many_arguments,
1537    clippy::cognitive_complexity,
1538    reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
1539)]
1540async fn run_server(
1541    router: axum::Router,
1542    listener: TcpListener,
1543    tls_paths: Option<(PathBuf, PathBuf)>,
1544    mtls_config: Option<MtlsConfig>,
1545    shutdown_timeout: Duration,
1546    auth_state: Option<Arc<AuthState>>,
1547    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
1548    mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
1549    ct: CancellationToken,
1550) -> anyhow::Result<()> {
1551    // `shutdown_trigger` fires when the FIRST source resolves: either
1552    // an OS signal (Ctrl-C / SIGTERM) or external cancellation of `ct`
1553    // (which the test harness uses for deterministic shutdown).
1554    let shutdown_trigger = CancellationToken::new();
1555    {
1556        let trigger = shutdown_trigger.clone();
1557        let parent = ct.clone();
1558        tokio::spawn(async move {
1559            tokio::select! {
1560                () = shutdown_signal() => {}
1561                () = parent.cancelled() => {}
1562            }
1563            trigger.cancel();
1564        });
1565    }
1566
1567    let graceful = {
1568        let trigger = shutdown_trigger.clone();
1569        let ct = ct.clone();
1570        async move {
1571            trigger.cancelled().await;
1572            tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
1573            ct.cancel();
1574        }
1575    };
1576
1577    let force_exit_timer = {
1578        let trigger = shutdown_trigger.clone();
1579        async move {
1580            trigger.cancelled().await;
1581            tokio::time::sleep(shutdown_timeout).await;
1582        }
1583    };
1584
1585    if let Some((cert_path, key_path)) = tls_paths {
1586        let crl_set = if let Some(mtls) = mtls_config.as_ref()
1587            && mtls.crl_enabled
1588        {
1589            let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
1590            let (crl_set, discover_rx) =
1591                mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
1592                    .await
1593                    .map_err(|error| anyhow::anyhow!(error.to_string()))?;
1594            tokio::spawn(mtls_revocation::run_crl_refresher(
1595                Arc::clone(&crl_set),
1596                discover_rx,
1597                ct.clone(),
1598            ));
1599            Some(crl_set)
1600        } else {
1601            None
1602        };
1603
1604        if let Some(cb) = on_reload_ready.take() {
1605            cb(ReloadHandle {
1606                auth: auth_state.clone(),
1607                rbac: Some(Arc::clone(&rbac_swap)),
1608                crl_set: crl_set.clone(),
1609            });
1610        }
1611
1612        let tls_listener = TlsListener::new(
1613            listener,
1614            &cert_path,
1615            &key_path,
1616            mtls_config.as_ref(),
1617            crl_set,
1618            TLS_HANDSHAKE_TIMEOUT,
1619        )?;
1620        let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
1621        tokio::select! {
1622            result = axum::serve(tls_listener, make_svc)
1623                .with_graceful_shutdown(graceful) => { result?; }
1624            () = force_exit_timer => {
1625                tracing::warn!("shutdown timeout exceeded, forcing exit");
1626            }
1627        }
1628    } else {
1629        if let Some(cb) = on_reload_ready.take() {
1630            cb(ReloadHandle {
1631                auth: auth_state,
1632                rbac: Some(rbac_swap),
1633                crl_set: None,
1634            });
1635        }
1636
1637        let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
1638        tokio::select! {
1639            result = axum::serve(listener, make_svc)
1640                .with_graceful_shutdown(graceful) => { result?; }
1641            () = force_exit_timer => {
1642                tracing::warn!("shutdown timeout exceeded, forcing exit");
1643            }
1644        }
1645    }
1646
1647    Ok(())
1648}
1649
/// Install the OAuth 2.1 proxy endpoints (`/authorize`, `/token`,
/// `/register`, and authorization server metadata) on `router`.
///
/// When `oauth_config.proxy` is `None` the router is returned unchanged.
/// The `/introspect` and `/revoke` admin endpoints are mounted only when
/// `proxy.expose_admin_endpoints` is set AND the corresponding upstream
/// URL (`introspection_url` / `revocation_url`) is configured.
///
/// # Errors
///
/// Returns [`McpxError::Startup`] if the shared
/// [`crate::oauth::OauthHttpClient`] cannot be initialized, or if the
/// admin endpoints require authentication but no `auth_state` was given.
#[cfg(feature = "oauth")]
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let Some(ref proxy) = oauth_config.proxy else {
        return Ok(router);
    };

    // Single shared HTTP client for all proxy endpoints. Cloning is
    // cheap (refcounted) and shares the underlying connection pool.
    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;

    let asm = crate::oauth::authorization_server_metadata(server_url, oauth_config);
    let router = router.route(
        "/.well-known/oauth-authorization-server",
        axum::routing::get(move || {
            let m = asm.clone();
            async move { axum::Json(m) }
        }),
    );

    let proxy_authorize = proxy.clone();
    let router = router.route(
        "/authorize",
        axum::routing::get(
            move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
                let p = proxy_authorize.clone();
                async move { crate::oauth::handle_authorize(&p, &query.unwrap_or_default()) }
            },
        ),
    );

    let proxy_token = proxy.clone();
    let token_http = http.clone();
    let router = router.route(
        "/token",
        axum::routing::post(move |body: String| {
            let p = proxy_token.clone();
            let h = token_http.clone();
            async move { crate::oauth::handle_token(&h, &p, &body).await }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    let proxy_register = proxy.clone();
    let router = router.route(
        "/register",
        axum::routing::post(move |axum::Json(body): axum::Json<serde_json::Value>| {
            let p = proxy_register;
            async move { axum::Json(crate::oauth::handle_register(&p, &body)) }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    // Admin routes exist only when exposed AND at least one upstream
    // endpoint is configured; otherwise nothing is mounted below.
    let admin_routes_enabled = proxy.expose_admin_endpoints
        && (proxy.introspection_url.is_some() || proxy.revocation_url.is_some());
    if admin_routes_enabled
        && !proxy.require_auth_on_admin_endpoints
        && proxy.allow_unauthenticated_admin_endpoints
    {
        // M3 escape-hatch in effect: validate() let this through because
        // the operator explicitly opted in. Surface it loudly at startup
        // so the choice is auditable in logs. Gated on
        // `admin_routes_enabled` so the warning is not emitted when no
        // introspect/revoke route is actually mounted.
        tracing::warn!(
            "OAuth introspect/revoke endpoints are unauthenticated by explicit \
             allow_unauthenticated_admin_endpoints opt-out; ensure an \
             authenticated reverse proxy fronts these routes"
        );
    }

    let admin_router = if admin_routes_enabled {
        build_oauth_admin_router(proxy, http, auth_state)?
    } else {
        axum::Router::new()
    };

    let router = router.merge(admin_router);

    tracing::info!(
        introspect = proxy.expose_admin_endpoints && proxy.introspection_url.is_some(),
        revoke = proxy.expose_admin_endpoints && proxy.revocation_url.is_some(),
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
1750
/// Assemble the optional `/introspect` and `/revoke` admin sub-router.
///
/// Each route is added only if its upstream URL is configured. The whole
/// sub-router is wrapped in [`oauth_token_cache_headers_middleware`] (RFC
/// 6749 §5.1 / RFC 6750 §5.4 cache headers). When
/// `proxy.require_auth_on_admin_endpoints` is set, it is additionally
/// wrapped in the auth middleware.
///
/// # Errors
///
/// Returns [`McpxError::Startup`] when authentication is required but no
/// `auth_state` was supplied.
#[cfg(feature = "oauth")]
fn build_oauth_admin_router(
    proxy: &crate::oauth::OAuthProxyConfig,
    http: crate::oauth::OauthHttpClient,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let mut routes = axum::Router::new();

    if proxy.introspection_url.is_some() {
        let cfg = proxy.clone();
        let client = http.clone();
        let introspect = move |body: String| {
            let cfg = cfg.clone();
            let client = client.clone();
            async move { crate::oauth::handle_introspect(&client, &cfg, &body).await }
        };
        routes = routes.route("/introspect", axum::routing::post(introspect));
    }

    if proxy.revocation_url.is_some() {
        let cfg = proxy.clone();
        let client = http;
        let revoke = move |body: String| {
            let cfg = cfg.clone();
            let client = client.clone();
            async move { crate::oauth::handle_revoke(&client, &cfg, &body).await }
        };
        routes = routes.route("/revoke", axum::routing::post(revoke));
    }

    let routes = routes.layer(axum::middleware::from_fn(
        oauth_token_cache_headers_middleware,
    ));

    if !proxy.require_auth_on_admin_endpoints {
        return Ok(routes);
    }

    let state = auth_state.map(Arc::clone).ok_or_else(|| {
        McpxError::Startup("oauth proxy admin endpoints require auth state".into())
    })?;
    Ok(routes.layer(axum::middleware::from_fn(move |req, next| {
        auth_middleware(Arc::clone(&state), req, next)
    })))
}
1809
1810/// Build the host allow-list for rmcp's DNS rebinding protection.
1811///
1812/// Includes loopback hosts by default, then augments with host/authority
1813/// derived from `public_url` and the server bind address.
1814fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
1815    let mut hosts = vec![
1816        "localhost".to_owned(),
1817        "127.0.0.1".to_owned(),
1818        "::1".to_owned(),
1819    ];
1820
1821    if let Some(url) = public_url
1822        && let Ok(uri) = url.parse::<axum::http::Uri>()
1823        && let Some(authority) = uri.authority()
1824    {
1825        let host = authority.host().to_owned();
1826        if !hosts.iter().any(|h| h == &host) {
1827            hosts.push(host);
1828        }
1829
1830        let authority = authority.as_str().to_owned();
1831        if !hosts.iter().any(|h| h == &authority) {
1832            hosts.push(authority);
1833        }
1834    }
1835
1836    if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
1837        && let Some(authority) = uri.authority()
1838    {
1839        let host = authority.host().to_owned();
1840        if !hosts.iter().any(|h| h == &host) {
1841            hosts.push(host);
1842        }
1843
1844        let authority = authority.as_str().to_owned();
1845        if !hosts.iter().any(|h| h == &authority) {
1846            hosts.push(authority);
1847        }
1848    }
1849
1850    hosts
1851}
1852
1853// - TLS support -
1854
1855/// Implement axum's `Connected` trait for `TlsConnInfo` so that
1856/// `ConnectInfo<TlsConnInfo>` is available in middleware when serving
1857/// over our custom `TlsListener`.
1858///
1859/// The identity is read directly from the wrapping
1860/// [`AuthenticatedTlsStream`], which guarantees one-to-one correspondence
1861/// between the TLS connection and its mTLS identity. This eliminates the
1862/// previous shared-map approach which was vulnerable to ephemeral-port
1863/// reuse races (an unauthenticated reconnection from the same `(IP, port)`
1864/// pair could alias a stale entry).
1865impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
1866    for TlsConnInfo
1867{
1868    fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
1869        let addr = *target.remote_addr();
1870        let identity = target.io().identity().cloned();
1871        TlsConnInfo::new(addr, identity)
1872    }
1873}
1874
/// Deadline for a single TLS handshake; connections exceeding it are
/// dropped. This stops idle or slow-loris peers from holding a handshake
/// worker task (and its semaphore permit) forever.
///
/// Tunability is deliberately deferred; tracked in
/// <https://github.com/andrico21/rmcp-server-kit/issues/9>.
const TLS_HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(10_000);

/// Cap on TLS handshakes running concurrently. Once every slot is taken,
/// the acceptor task stops pulling connections out of the kernel backlog
/// (backpressure) rather than accepting them and discarding them in user
/// space.
///
/// Tunability is deliberately deferred; tracked in
/// <https://github.com/andrico21/rmcp-server-kit/issues/9>.
const MAX_INFLIGHT_TLS_HANDSHAKES: usize = 256;

/// Size of the queue of finished handshakes waiting for `axum::serve`'s
/// `accept()` loop. A handshake worker awaits on `send` while the queue
/// is full. A slow accept loop therefore throttles handshakes instead of
/// letting completed connections pile up without bound.
const TLS_ACCEPT_CHANNEL_CAPACITY: usize = 32;
1896
/// A TLS-wrapping listener that implements axum's `Listener` trait.
///
/// TCP accepts and TLS handshakes run on a dedicated background task: each
/// accepted connection's handshake is spawned onto its own worker task,
/// bounded by [`MAX_INFLIGHT_TLS_HANDSHAKES`] concurrent handshakes and a
/// per-handshake [`TLS_HANDSHAKE_TIMEOUT`]. A slow or idle client therefore
/// cannot stall other connections behind a serialized inline handshake.
///
/// When mTLS is configured, client certificates are verified against the
/// configured CA and the client identity is extracted at handshake time.
/// The extracted identity is bound to the connection itself via the
/// returned [`AuthenticatedTlsStream`], so it is impossible for an
/// unrelated connection to observe it.
struct TlsListener {
    /// Bound address, captured eagerly before the `TcpListener` moves into
    /// the acceptor task; `Listener::local_addr` returns this copy.
    local_addr: SocketAddr,
    /// Completed handshakes produced by the acceptor task's workers.
    /// Dropping this receiver closes the channel. In-flight workers see the
    /// closure when `send` fails, and the acceptor task sees it through its
    /// `tx.is_closed()` check.
    rx: mpsc::Receiver<(AuthenticatedTlsStream, SocketAddr)>,
    /// Background task driving TCP accepts and concurrent TLS handshakes.
    /// Aborted on drop so the listener releases its port deterministically.
    acceptor_task: tokio::task::JoinHandle<()>,
}
1920
1921impl TlsListener {
1922    fn new(
1923        inner: TcpListener,
1924        cert_path: &Path,
1925        key_path: &Path,
1926        mtls_config: Option<&MtlsConfig>,
1927        crl_set: Option<Arc<CrlSet>>,
1928        handshake_timeout: Duration,
1929    ) -> anyhow::Result<Self> {
1930        // Install the ring crypto provider (ok to call multiple times).
1931        rustls::crypto::ring::default_provider()
1932            .install_default()
1933            .ok();
1934
1935        let certs = load_certs(cert_path)?;
1936        let key = load_key(key_path)?;
1937
1938        let mtls_default_role;
1939
1940        let tls_config = if let Some(mtls) = mtls_config {
1941            mtls_default_role = mtls.default_role.clone();
1942            let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
1943            {
1944                let Some(crl_set) = crl_set else {
1945                    return Err(anyhow::anyhow!(
1946                        "mTLS CRL verifier requested but CRL state was not initialized"
1947                    ));
1948                };
1949                Arc::new(DynamicClientCertVerifier::new(crl_set))
1950            } else {
1951                let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
1952                if mtls.required {
1953                    rustls::server::WebPkiClientVerifier::builder(root_store)
1954                        .build()
1955                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1956                } else {
1957                    rustls::server::WebPkiClientVerifier::builder(root_store)
1958                        .allow_unauthenticated()
1959                        .build()
1960                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
1961                }
1962            };
1963
1964            tracing::info!(
1965                ca = %mtls.ca_cert_path.display(),
1966                required = mtls.required,
1967                crl_enabled = mtls.crl_enabled,
1968                "mTLS client auth configured"
1969            );
1970
1971            rustls::ServerConfig::builder_with_protocol_versions(&[
1972                &rustls::version::TLS12,
1973                &rustls::version::TLS13,
1974            ])
1975            .with_client_cert_verifier(verifier)
1976            .with_single_cert(certs, key)?
1977        } else {
1978            mtls_default_role = "viewer".to_owned();
1979            rustls::ServerConfig::builder_with_protocol_versions(&[
1980                &rustls::version::TLS12,
1981                &rustls::version::TLS13,
1982            ])
1983            .with_no_client_auth()
1984            .with_single_cert(certs, key)?
1985        };
1986
1987        let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
1988        tracing::info!(
1989            "TLS enabled (cert: {}, key: {})",
1990            cert_path.display(),
1991            key_path.display()
1992        );
1993        let local_addr = inner.local_addr()?;
1994        let (tx, rx) = mpsc::channel(TLS_ACCEPT_CHANNEL_CAPACITY);
1995        let acceptor_task = tokio::spawn(run_tls_acceptor(
1996            inner,
1997            acceptor,
1998            mtls_default_role,
1999            tx,
2000            handshake_timeout,
2001        ));
2002        Ok(Self {
2003            local_addr,
2004            rx,
2005            acceptor_task,
2006        })
2007    }
2008
2009    /// Extract the mTLS client cert identity from a completed TLS handshake.
2010    /// Returns `None` if no client certificate was presented or if the
2011    /// certificate could not be parsed into an [`AuthIdentity`].
2012    fn extract_handshake_identity(
2013        tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
2014        default_role: &str,
2015        addr: SocketAddr,
2016    ) -> Option<AuthIdentity> {
2017        let (_, server_conn) = tls_stream.get_ref();
2018        let cert_der = server_conn.peer_certificates()?.first()?;
2019        let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
2020        tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
2021        Some(id)
2022    }
2023}
2024
2025/// Drive TCP accepts and concurrent TLS handshakes for [`TlsListener`].
2026///
2027/// Each accepted connection's handshake runs on its own worker task under
2028/// a permit from a [`MAX_INFLIGHT_TLS_HANDSHAKES`]-sized semaphore and a
2029/// `handshake_timeout` deadline. Completed handshakes are pushed to `tx`;
2030/// failures and timeouts are logged at DEBUG and the connection dropped.
2031/// The loop exits when the owning [`TlsListener`] is dropped.
2032async fn run_tls_acceptor(
2033    listener: TcpListener,
2034    acceptor: tokio_rustls::TlsAcceptor,
2035    default_role: String,
2036    tx: mpsc::Sender<(AuthenticatedTlsStream, SocketAddr)>,
2037    handshake_timeout: Duration,
2038) {
2039    let inflight = Arc::new(Semaphore::new(MAX_INFLIGHT_TLS_HANDSHAKES));
2040    loop {
2041        // Acquire the permit BEFORE accepting: at saturation, pending
2042        // connections wait in the kernel backlog instead of being accepted
2043        // and then buffered or dropped in user space.
2044        let Ok(permit) = Arc::clone(&inflight).acquire_owned().await else {
2045            // The semaphore is never closed; defensive exit.
2046            return;
2047        };
2048        let (stream, addr) = match listener.accept().await {
2049            Ok(pair) => pair,
2050            Err(e) => {
2051                tracing::debug!("TCP accept error: {e}");
2052                continue;
2053            }
2054        };
2055        if tx.is_closed() {
2056            // The listener was dropped (shutdown): stop accepting.
2057            return;
2058        }
2059        let acceptor = acceptor.clone();
2060        let default_role = default_role.clone();
2061        let tx = tx.clone();
2062        tokio::spawn(async move {
2063            let _permit = permit;
2064            match tokio::time::timeout(handshake_timeout, acceptor.accept(stream)).await {
2065                Ok(Ok(tls_stream)) => {
2066                    let identity =
2067                        TlsListener::extract_handshake_identity(&tls_stream, &default_role, addr);
2068                    let wrapped = AuthenticatedTlsStream {
2069                        inner: tls_stream,
2070                        identity,
2071                    };
2072                    // The receiver only disappears during shutdown; discard
2073                    // the completed connection quietly rather than logging.
2074                    let _ = tx.send((wrapped, addr)).await;
2075                }
2076                Ok(Err(e)) => {
2077                    tracing::debug!("TLS handshake failed from {addr}: {e}");
2078                }
2079                Err(_elapsed) => {
2080                    tracing::debug!(
2081                        "TLS handshake timed out from {addr} after {handshake_timeout:?}"
2082                    );
2083                }
2084            }
2085        });
2086    }
2087}
2088
/// A TLS stream paired with the mTLS identity extracted at handshake time.
///
/// Wraps [`tokio_rustls::server::TlsStream`] so the verified client
/// identity travels with the connection itself. This replaces the previous
/// shared `MtlsIdentities` map, eliminating the
/// `(SocketAddr) -> AuthIdentity` aliasing risk caused by ephemeral-port
/// reuse and removing the need for an LRU eviction policy.
///
/// The wrapper is `Unpin` because its inner stream is `Unpin`
/// ([`tokio::net::TcpStream`] is `Unpin`); this also assumes
/// `AuthIdentity` is `Unpin`. `AsyncRead`/`AsyncWrite` delegation
/// therefore uses safe pin projection via `Pin::new(&mut self.inner)`.
pub(crate) struct AuthenticatedTlsStream {
    /// The handshaken TLS stream all I/O is delegated to.
    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
    /// Identity extracted from the client certificate; `None` when no
    /// certificate was presented or it could not be parsed.
    identity: Option<AuthIdentity>,
}
2104
2105impl AuthenticatedTlsStream {
2106    /// Returns the verified mTLS client identity, if any.
2107    #[must_use]
2108    pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
2109        self.identity.as_ref()
2110    }
2111}
2112
2113impl std::fmt::Debug for AuthenticatedTlsStream {
2114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2115        f.debug_struct("AuthenticatedTlsStream")
2116            .field("identity", &self.identity.as_ref().map(|id| &id.name))
2117            .finish_non_exhaustive()
2118    }
2119}
2120
2121impl tokio::io::AsyncRead for AuthenticatedTlsStream {
2122    fn poll_read(
2123        mut self: Pin<&mut Self>,
2124        cx: &mut std::task::Context<'_>,
2125        buf: &mut tokio::io::ReadBuf<'_>,
2126    ) -> std::task::Poll<std::io::Result<()>> {
2127        Pin::new(&mut self.inner).poll_read(cx, buf)
2128    }
2129}
2130
2131impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
2132    fn poll_write(
2133        mut self: Pin<&mut Self>,
2134        cx: &mut std::task::Context<'_>,
2135        buf: &[u8],
2136    ) -> std::task::Poll<std::io::Result<usize>> {
2137        Pin::new(&mut self.inner).poll_write(cx, buf)
2138    }
2139
2140    fn poll_flush(
2141        mut self: Pin<&mut Self>,
2142        cx: &mut std::task::Context<'_>,
2143    ) -> std::task::Poll<std::io::Result<()>> {
2144        Pin::new(&mut self.inner).poll_flush(cx)
2145    }
2146
2147    fn poll_shutdown(
2148        mut self: Pin<&mut Self>,
2149        cx: &mut std::task::Context<'_>,
2150    ) -> std::task::Poll<std::io::Result<()>> {
2151        Pin::new(&mut self.inner).poll_shutdown(cx)
2152    }
2153
2154    fn poll_write_vectored(
2155        mut self: Pin<&mut Self>,
2156        cx: &mut std::task::Context<'_>,
2157        bufs: &[std::io::IoSlice<'_>],
2158    ) -> std::task::Poll<std::io::Result<usize>> {
2159        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
2160    }
2161
2162    fn is_write_vectored(&self) -> bool {
2163        self.inner.is_write_vectored()
2164    }
2165}
2166
2167impl axum::serve::Listener for TlsListener {
2168    type Io = AuthenticatedTlsStream;
2169    type Addr = SocketAddr;
2170
2171    /// Yield the next fully-handshaken TLS connection.
2172    ///
2173    /// Cancel-safe: this is a plain `mpsc::Receiver::recv`, so cancelling
2174    /// the future (axum selects it against graceful shutdown) never loses
2175    /// a connection.
2176    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
2177        if let Some(pair) = self.rx.recv().await {
2178            return pair;
2179        }
2180        // The channel only closes if the acceptor task terminated, which
2181        // means the TcpListener is gone and the OS already refuses new
2182        // connections. `Listener::accept` is infallible and panicking is
2183        // forbidden, so park forever: existing connections keep being
2184        // served and graceful shutdown still completes.
2185        tracing::error!("TLS acceptor task terminated; no further connections will be accepted");
2186        std::future::pending().await
2187    }
2188
2189    fn local_addr(&self) -> std::io::Result<Self::Addr> {
2190        Ok(self.local_addr)
2191    }
2192}
2193
2194impl Drop for TlsListener {
2195    fn drop(&mut self) {
2196        // Stop accepting immediately and release the bound port. In-flight
2197        // handshake workers notice the closed channel and exit quietly.
2198        self.acceptor_task.abort();
2199    }
2200}
2201
2202fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
2203    use rustls::pki_types::pem::PemObject;
2204    let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
2205        .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
2206        .collect::<Result<_, _>>()
2207        .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
2208    anyhow::ensure!(
2209        !certs.is_empty(),
2210        "no certificates found in {}",
2211        path.display()
2212    );
2213    Ok(certs)
2214}
2215
2216fn load_client_auth_roots(
2217    path: &Path,
2218) -> anyhow::Result<(
2219    Vec<rustls::pki_types::CertificateDer<'static>>,
2220    Arc<RootCertStore>,
2221)> {
2222    let ca_certs = load_certs(path)?;
2223    let mut root_store = RootCertStore::empty();
2224    for cert in &ca_certs {
2225        root_store
2226            .add(cert.clone())
2227            .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
2228    }
2229
2230    Ok((ca_certs, Arc::new(root_store)))
2231}
2232
2233fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
2234    use rustls::pki_types::pem::PemObject;
2235    rustls::pki_types::PrivateKeyDer::from_pem_file(path)
2236        .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
2237}
2238
2239#[allow(
2240    clippy::unused_async,
2241    reason = "axum route handler signature requires `async fn` even when the body is synchronous"
2242)]
2243async fn healthz() -> impl IntoResponse {
2244    axum::Json(serde_json::json!({
2245        "status": "ok",
2246    }))
2247}
2248
2249/// Build the `/version` JSON payload for a given server name and version.
2250///
2251/// Build metadata (`build_git_sha`, `build_timestamp`, `rust_version`) is
2252/// read at compile time from the `RMCP_SERVER_KIT_BUILD_SHA`,
2253/// `RMCP_SERVER_KIT_BUILD_TIME`, and `RMCP_SERVER_KIT_RUSTC_VERSION` env
2254/// vars. Unset values resolve to `"unknown"`.
2255fn version_payload(name: &str, version: &str) -> serde_json::Value {
2256    serde_json::json!({
2257        "name": name,
2258        "version": version,
2259        "build_git_sha": option_env!("RMCP_SERVER_KIT_BUILD_SHA").unwrap_or("unknown"),
2260        "build_timestamp": option_env!("RMCP_SERVER_KIT_BUILD_TIME").unwrap_or("unknown"),
2261        "rust_version": option_env!("RMCP_SERVER_KIT_RUSTC_VERSION").unwrap_or("unknown"),
2262        "mcpx_version": env!("CARGO_PKG_VERSION"),
2263    })
2264}
2265
2266/// Pre-serialize the `/version` payload to immutable bytes.
2267///
2268/// This is called once at router-build time so per-request handling can
2269/// reuse a cheap `Arc<[u8]>` clone instead of re-serializing a
2270/// [`serde_json::Value`] on every hit.
2271///
2272/// Serialization of a flat `serde_json::Value` of static-string fields
2273/// cannot fail in practice; the fallback to `b"{}"` exists only to
2274/// satisfy the crate-wide `unwrap_used` / `expect_used` lint policy.
2275fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
2276    let value = version_payload(name, version);
2277    serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
2278}
2279
2280async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2281    let status = check().await;
2282    let ready = status
2283        .get("ready")
2284        .and_then(serde_json::Value::as_bool)
2285        .unwrap_or(false);
2286    let code = if ready {
2287        axum::http::StatusCode::OK
2288    } else {
2289        axum::http::StatusCode::SERVICE_UNAVAILABLE
2290    };
2291    (code, axum::Json(status))
2292}
2293
2294/// Wait for SIGINT (ctrl-c) or SIGTERM (container stop).
2295///
2296/// On non-Unix platforms, only SIGINT is handled.
2297async fn shutdown_signal() {
2298    let ctrl_c = tokio::signal::ctrl_c();
2299
2300    #[cfg(unix)]
2301    {
2302        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
2303            Ok(mut term) => {
2304                tokio::select! {
2305                    _ = ctrl_c => {}
2306                    _ = term.recv() => {}
2307                }
2308            }
2309            Err(e) => {
2310                tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
2311                ctrl_c.await.ok();
2312            }
2313        }
2314    }
2315
2316    #[cfg(not(unix))]
2317    {
2318        ctrl_c.await.ok();
2319    }
2320}
2321
// -- HTTP middleware (metrics, security headers, OAuth cache headers, and Origin validation per MCP 2025-11-25 spec, section 2.0.1) --
2323
/// Middleware that records HTTP request metrics.
///
/// For every request it increments `http_requests_total` (labels: method,
/// path, status) and observes the handling time in seconds on
/// `http_request_duration_seconds` (labels: method, path).
///
/// The `path` label is the matched route template taken from axum's
/// `MatchedPath` request extension, not the raw request URI. The raw path
/// is chosen by the client, so labelling by it would let anyone create an
/// unbounded number of metric time series just by requesting distinct
/// URLs. Requests that matched no route share the single label
/// `"<unmatched>"`.
///
/// NOTE(review): axum populates `MatchedPath` for middleware installed on
/// a `Router` (e.g. via `Router::layer`). If this middleware is ever
/// mounted outside the router, every request will be labelled
/// `"<unmatched>"`.
#[cfg(feature = "metrics")]
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    /// Shared `path` label for requests that did not match any route.
    const UNMATCHED_ROUTE: &str = "<unmatched>";

    let method = req.method().to_string();
    // Owned copy so the borrow of `req` ends before it is moved into `next`.
    let path = req
        .extensions()
        .get::<axum::extract::MatchedPath>()
        .map_or(UNMATCHED_ROUTE, axum::extract::MatchedPath::as_str)
        .to_owned();
    let start = std::time::Instant::now();

    let response = next.run(req).await;

    let status = response.status().as_u16().to_string();
    let duration = start.elapsed().as_secs_f64();

    metrics
        .http_requests_total
        .with_label_values(&[&method, &path, &status])
        .inc();
    metrics
        .http_request_duration_seconds
        .with_label_values(&[&method, &path])
        .observe(duration);

    response
}
2353
2354/// OWASP security header hardening applied to every response.
2355///
2356/// Sets: `X-Content-Type-Options`, `X-Frame-Options`, `Cache-Control`,
2357/// `Referrer-Policy`, `Cross-Origin-Opener-Policy`, `Cross-Origin-Resource-Policy`,
2358/// `Cross-Origin-Embedder-Policy`, `Permissions-Policy`,
2359/// `X-Permitted-Cross-Domain-Policies`, `Content-Security-Policy`,
2360/// `X-DNS-Prefetch-Control`, and (when TLS is active) `Strict-Transport-Security`.
2361///
2362/// Each header's value can be customised via [`SecurityHeadersConfig`]
2363/// on [`McpServerConfig`]. See that type for the three-state semantic
2364/// (`None` = default, `Some("")` = omit, `Some(v)` = override).
2365async fn security_headers_middleware(
2366    is_tls: bool,
2367    cfg: Arc<SecurityHeadersConfig>,
2368    req: Request<Body>,
2369    next: Next,
2370) -> axum::response::Response {
2371    use axum::http::{HeaderName, header};
2372
2373    let mut resp = next.run(req).await;
2374    let headers = resp.headers_mut();
2375
2376    // Strip server identity headers to reduce information leakage.
2377    headers.remove(header::SERVER);
2378    headers.remove(HeaderName::from_static("x-powered-by"));
2379
2380    apply_security_header(
2381        headers,
2382        header::X_CONTENT_TYPE_OPTIONS,
2383        cfg.x_content_type_options.as_deref(),
2384        "nosniff",
2385    );
2386    apply_security_header(
2387        headers,
2388        header::X_FRAME_OPTIONS,
2389        cfg.x_frame_options.as_deref(),
2390        "deny",
2391    );
2392    apply_security_header(
2393        headers,
2394        header::CACHE_CONTROL,
2395        cfg.cache_control.as_deref(),
2396        "no-store, max-age=0",
2397    );
2398    apply_security_header(
2399        headers,
2400        header::REFERRER_POLICY,
2401        cfg.referrer_policy.as_deref(),
2402        "no-referrer",
2403    );
2404    apply_security_header(
2405        headers,
2406        HeaderName::from_static("cross-origin-opener-policy"),
2407        cfg.cross_origin_opener_policy.as_deref(),
2408        "same-origin",
2409    );
2410    apply_security_header(
2411        headers,
2412        HeaderName::from_static("cross-origin-resource-policy"),
2413        cfg.cross_origin_resource_policy.as_deref(),
2414        "same-origin",
2415    );
2416    apply_security_header(
2417        headers,
2418        HeaderName::from_static("cross-origin-embedder-policy"),
2419        cfg.cross_origin_embedder_policy.as_deref(),
2420        "require-corp",
2421    );
2422    apply_security_header(
2423        headers,
2424        HeaderName::from_static("permissions-policy"),
2425        cfg.permissions_policy.as_deref(),
2426        "accelerometer=(), camera=(), geolocation=(), microphone=()",
2427    );
2428    apply_security_header(
2429        headers,
2430        HeaderName::from_static("x-permitted-cross-domain-policies"),
2431        cfg.x_permitted_cross_domain_policies.as_deref(),
2432        "none",
2433    );
2434    apply_security_header(
2435        headers,
2436        HeaderName::from_static("content-security-policy"),
2437        cfg.content_security_policy.as_deref(),
2438        "default-src 'none'; frame-ancestors 'none'",
2439    );
2440    apply_security_header(
2441        headers,
2442        HeaderName::from_static("x-dns-prefetch-control"),
2443        cfg.x_dns_prefetch_control.as_deref(),
2444        "off",
2445    );
2446
2447    if is_tls {
2448        apply_security_header(
2449            headers,
2450            header::STRICT_TRANSPORT_SECURITY,
2451            cfg.strict_transport_security.as_deref(),
2452            "max-age=63072000; includeSubDomains",
2453        );
2454    }
2455
2456    resp
2457}
2458
2459/// Set a single security header on the response, honouring the
2460/// three-state override semantic (None = default, Some("") = omit,
2461/// Some(value) = override).
2462///
2463/// Defence-in-depth: if an override value somehow reaches this point
2464/// despite [`validate_security_headers`] having approved it (e.g. a
2465/// runtime mutation on a non-`Validated` field), we log at error level
2466/// and fall back to the static default rather than panicking. The
2467/// `Validated<McpServerConfig>` type makes that path unreachable in
2468/// well-typed code paths.
2469fn apply_security_header(
2470    headers: &mut axum::http::HeaderMap,
2471    name: axum::http::HeaderName,
2472    override_value: Option<&str>,
2473    default: &'static str,
2474) {
2475    use axum::http::HeaderValue;
2476
2477    match override_value {
2478        None => {
2479            headers.insert(name, HeaderValue::from_static(default));
2480        }
2481        Some("") => {
2482            // Operator explicitly opted out of this header.
2483        }
2484        Some(v) => match HeaderValue::from_str(v) {
2485            Ok(hv) => {
2486                headers.insert(name, hv);
2487            }
2488            Err(err) => {
2489                tracing::error!(
2490                    header = %name,
2491                    error = %err,
2492                    "invalid security header override reached middleware; using default"
2493                );
2494                headers.insert(name, HeaderValue::from_static(default));
2495            }
2496        },
2497    }
2498}
2499
2500/// Validate every non-empty entry in a [`SecurityHeadersConfig`].
2501///
2502/// - `None` and `Some("")` are accepted unconditionally (use-default and
2503///   omit, respectively).
2504/// - `Some(v)` is rejected if `axum::http::HeaderValue::from_str(v)` fails.
2505/// - `strict_transport_security` additionally rejects any value
2506///   containing `preload` (case-insensitive). Operators who genuinely
2507///   want to commit to the HSTS preload list must do so via a future
2508///   explicit `with_hsts_preload(true)` builder, not by smuggling
2509///   `preload` through this knob.
2510fn validate_security_headers(cfg: &SecurityHeadersConfig) -> Result<(), McpxError> {
2511    use axum::http::HeaderValue;
2512
2513    let fields: &[(&str, Option<&str>)] = &[
2514        (
2515            "x_content_type_options",
2516            cfg.x_content_type_options.as_deref(),
2517        ),
2518        ("x_frame_options", cfg.x_frame_options.as_deref()),
2519        ("cache_control", cfg.cache_control.as_deref()),
2520        ("referrer_policy", cfg.referrer_policy.as_deref()),
2521        (
2522            "cross_origin_opener_policy",
2523            cfg.cross_origin_opener_policy.as_deref(),
2524        ),
2525        (
2526            "cross_origin_resource_policy",
2527            cfg.cross_origin_resource_policy.as_deref(),
2528        ),
2529        (
2530            "cross_origin_embedder_policy",
2531            cfg.cross_origin_embedder_policy.as_deref(),
2532        ),
2533        ("permissions_policy", cfg.permissions_policy.as_deref()),
2534        (
2535            "x_permitted_cross_domain_policies",
2536            cfg.x_permitted_cross_domain_policies.as_deref(),
2537        ),
2538        (
2539            "content_security_policy",
2540            cfg.content_security_policy.as_deref(),
2541        ),
2542        (
2543            "x_dns_prefetch_control",
2544            cfg.x_dns_prefetch_control.as_deref(),
2545        ),
2546        (
2547            "strict_transport_security",
2548            cfg.strict_transport_security.as_deref(),
2549        ),
2550    ];
2551
2552    for (field, value) in fields {
2553        let Some(v) = value else { continue };
2554        if v.is_empty() {
2555            continue;
2556        }
2557        if let Err(err) = HeaderValue::from_str(v) {
2558            return Err(McpxError::Config(format!(
2559                "invalid security_headers.{field}: {err}"
2560            )));
2561        }
2562    }
2563
2564    if let Some(v) = cfg.strict_transport_security.as_deref()
2565        && !v.is_empty()
2566        && v.to_ascii_lowercase().contains("preload")
2567    {
2568        return Err(McpxError::Config(format!(
2569            "invalid security_headers.strict_transport_security: {v:?} contains the `preload` directive; \
2570             HSTS preload must be opted into explicitly via a dedicated builder, not via this knob"
2571        )));
2572    }
2573
2574    Ok(())
2575}
2576
/// Add the RFC 6749 §5.1 / RFC 6750 §5.4 cache and `Vary` headers required
/// on OAuth token-issuing responses.
///
/// `Cache-Control: no-store, max-age=0` is already applied globally by
/// [`security_headers_middleware`]; this middleware adds:
///
/// - `Pragma: no-cache`: mandated by RFC 6749 §5.1 for HTTP/1.0 caches.
///   Inserted, so it replaces any existing `Pragma`.
/// - `Vary: Authorization`: mandated by RFC 6750 §5.4 for endpoints whose
///   response depends on the `Authorization` header. Appended (not
///   inserted), so any existing `Vary` value (e.g. `Accept-Encoding` from a
///   compression layer, or `Origin` from a CORS layer) is preserved.
///
/// Applied only to the OAuth proxy token-class endpoints (`/token`,
/// `/register`, `/introspect`, `/revoke`).
#[cfg(feature = "oauth")]
async fn oauth_token_cache_headers_middleware(
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    use axum::http::{HeaderValue, header};

    let (mut parts, body) = next.run(req).await.into_parts();
    parts
        .headers
        .insert(header::PRAGMA, HeaderValue::from_static("no-cache"));
    parts
        .headers
        .append(header::VARY, HeaderValue::from_static("Authorization"));
    axum::http::Response::from_parts(parts, body)
}
2604
2605/// Per the MCP spec: if the Origin header is present and its value is not in
2606/// the allowed list, respond with 403 Forbidden. Requests without an Origin
2607/// header are allowed through (e.g. non-browser clients like curl, SDKs).
2608async fn origin_check_middleware(
2609    allowed: Arc<[String]>,
2610    log_request_headers: bool,
2611    req: Request<Body>,
2612    next: Next,
2613) -> axum::response::Response {
2614    let method = req.method().clone();
2615    let path = req.uri().path().to_owned();
2616
2617    log_incoming_request(&method, &path, req.headers(), log_request_headers);
2618
2619    if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
2620        let origin_str = origin.to_str().unwrap_or("");
2621        if !allowed.iter().any(|a| a == origin_str) {
2622            tracing::warn!(
2623                origin = origin_str,
2624                %method,
2625                %path,
2626                allowed = ?&*allowed,
2627                "rejected request: Origin not allowed"
2628            );
2629            return (
2630                axum::http::StatusCode::FORBIDDEN,
2631                "Forbidden: Origin not allowed",
2632            )
2633                .into_response();
2634        }
2635    }
2636    next.run(req).await
2637}
2638
2639/// Emit a DEBUG log for an incoming request, optionally including the full
2640/// (redacted) header set.
2641fn log_incoming_request(
2642    method: &axum::http::Method,
2643    path: &str,
2644    headers: &axum::http::HeaderMap,
2645    log_request_headers: bool,
2646) {
2647    if log_request_headers {
2648        tracing::debug!(
2649            %method,
2650            %path,
2651            headers = %format_request_headers_for_log(headers),
2652            "incoming request"
2653        );
2654    } else {
2655        tracing::debug!(%method, %path, "incoming request");
2656    }
2657}
2658
2659fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
2660    headers
2661        .iter()
2662        .map(|(k, v)| {
2663            let name = k.as_str();
2664            if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
2665                format!("{name}: [REDACTED]")
2666            } else {
2667                format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
2668            }
2669        })
2670        .collect::<Vec<_>>()
2671        .join(", ")
2672}
2673
2674// -- stdio transport --
2675
2676/// Serve an MCP server over stdin/stdout (stdio transport).
2677///
2678/// # Security warnings
2679///
2680/// - **No authentication**: the parent process has full, unrestricted access.
2681/// - **No RBAC**: all tools are available regardless of policy.
2682/// - **No TLS**: messages travel over OS pipes in plaintext.
2683/// - **Single client**: only the parent process can connect.
2684/// - **No Origin validation**: not applicable to stdio.
2685///
2686/// Use this only when the MCP client spawns the server as a trusted subprocess
2687/// (e.g. Claude Desktop, VS Code Copilot). For network-accessible deployments,
2688/// use `serve()` (Streamable HTTP) instead.
2689///
2690/// # Errors
2691///
2692/// Returns [`McpxError::Startup`] if the handler fails to initialize or the
2693/// transport disconnects unexpectedly.
2694// NOTE: reported complexity 32/25 is driven entirely by `tracing::*!`
2695// macro expansion in this 18-line function (info/warn/info + two matches).
2696// There is nothing meaningful to extract; the allow stays.
2697#[allow(
2698    clippy::cognitive_complexity,
2699    reason = "complexity is purely tracing macro expansion (info/warn + match arms); 18 lines of straight-line code, nothing meaningful to extract"
2700)]
2701pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
2702where
2703    H: ServerHandler + 'static,
2704{
2705    use rmcp::ServiceExt as _;
2706
2707    tracing::info!("stdio transport: serving on stdin/stdout");
2708    tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
2709
2710    let transport = rmcp::transport::io::stdio();
2711
2712    let service = handler
2713        .serve(transport)
2714        .await
2715        .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
2716
2717    if let Err(e) = service.waiting().await {
2718        tracing::warn!(error = %e, "stdio session ended with error");
2719    }
2720    tracing::info!("stdio session ended");
2721    Ok(())
2722}
2723
2724#[cfg(test)]
2725mod tests {
2726    #![allow(
2727        clippy::unwrap_used,
2728        clippy::expect_used,
2729        clippy::panic,
2730        clippy::indexing_slicing,
2731        clippy::unwrap_in_result,
2732        clippy::print_stdout,
2733        clippy::print_stderr,
2734        deprecated,
2735        reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
2736    )]
2737    use std::{sync::Arc, time::Duration};
2738
2739    use axum::{
2740        body::Body,
2741        http::{Request, StatusCode, header},
2742        response::IntoResponse,
2743    };
2744    use http_body_util::BodyExt;
2745    use tower::ServiceExt as _;
2746
2747    use super::*;
2748
2749    // -- McpServerConfig --
2750
2751    #[test]
2752    fn server_config_new_defaults() {
2753        let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
2754        assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
2755        assert_eq!(cfg.name, "test-server");
2756        assert_eq!(cfg.version, "1.0.0");
2757        assert!(cfg.tls_cert_path.is_none());
2758        assert!(cfg.tls_key_path.is_none());
2759        assert!(cfg.auth.is_none());
2760        assert!(cfg.rbac.is_none());
2761        assert!(cfg.allowed_origins.is_empty());
2762        assert!(cfg.tool_rate_limit.is_none());
2763        assert!(cfg.readiness_check.is_none());
2764        assert_eq!(cfg.max_request_body, 1024 * 1024);
2765        assert_eq!(cfg.request_timeout, Duration::from_mins(2));
2766        assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
2767        assert!(!cfg.log_request_headers);
2768    }
2769
2770    #[test]
2771    fn validate_consumes_and_proves() {
2772        // Valid config -> Validated wrapper, original is consumed.
2773        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
2774        let validated = cfg.validate().expect("valid config");
2775        // as_inner() gives read-only access to inner fields.
2776        assert_eq!(validated.as_inner().name, "test-server");
2777        // into_inner recovers the raw value.
2778        let raw = validated.into_inner();
2779        assert_eq!(raw.name, "test-server");
2780
2781        // Invalid config (zero max_request_body) -> Err.
2782        let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
2783        bad.max_request_body = 0;
2784        assert!(bad.validate().is_err(), "zero body cap must fail validate");
2785    }
2786
2787    #[test]
2788    fn validate_rejects_zero_max_concurrent_requests() {
2789        let cfg =
2790            McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_max_concurrent_requests(0);
2791        let err = cfg.validate().expect_err("zero concurrency cap must fail");
2792        assert!(
2793            format!("{err}").contains("max_concurrent_requests"),
2794            "error should mention max_concurrent_requests, got: {err}"
2795        );
2796    }
2797
2798    #[test]
2799    fn validate_rejects_zero_max_tracked_keys() {
2800        // Defaults mirror auth::default_max_attempts / default_idle_eviction
2801        // (module-private in auth.rs); spelled out here for review clarity.
2802        let rl = crate::auth::RateLimitConfig {
2803            max_attempts_per_minute: 30,
2804            pre_auth_max_per_minute: None,
2805            max_tracked_keys: 0,
2806            idle_eviction: Duration::from_secs(15 * 60),
2807        };
2808        let auth_cfg = AuthConfig {
2809            enabled: true,
2810            api_keys: Vec::new(),
2811            mtls: None,
2812            rate_limit: Some(rl),
2813            #[cfg(feature = "oauth")]
2814            oauth: None,
2815        };
2816        let cfg = McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_auth(auth_cfg);
2817        let err = cfg.validate().expect_err("zero max_tracked_keys must fail");
2818        assert!(
2819            format!("{err}").contains("max_tracked_keys"),
2820            "error should mention max_tracked_keys, got: {err}"
2821        );
2822    }
2823
2824    #[test]
2825    fn derive_allowed_hosts_includes_public_host() {
2826        let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
2827        assert!(
2828            hosts.iter().any(|h| h == "mcp.example.com"),
2829            "public_url host must be allowed"
2830        );
2831    }
2832
2833    #[test]
2834    fn derive_allowed_hosts_includes_bind_authority() {
2835        let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
2836        assert!(
2837            hosts.iter().any(|h| h == "127.0.0.1"),
2838            "bind host must be allowed"
2839        );
2840        assert!(
2841            hosts.iter().any(|h| h == "127.0.0.1:8080"),
2842            "bind authority must be allowed"
2843        );
2844    }
2845
2846    // -- healthz --
2847
2848    #[tokio::test]
2849    async fn healthz_returns_ok_json() {
2850        let resp = healthz().await.into_response();
2851        assert_eq!(resp.status(), StatusCode::OK);
2852        let body = resp.into_body().collect().await.unwrap().to_bytes();
2853        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
2854        assert_eq!(json["status"], "ok");
2855        assert!(
2856            json.get("name").is_none(),
2857            "healthz must not expose server name"
2858        );
2859        assert!(
2860            json.get("version").is_none(),
2861            "healthz must not expose version"
2862        );
2863    }
2864
2865    // -- readyz --
2866
2867    #[tokio::test]
2868    async fn readyz_returns_ok_when_ready() {
2869        let check: ReadinessCheck =
2870            Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
2871        let resp = readyz(check).await.into_response();
2872        assert_eq!(resp.status(), StatusCode::OK);
2873        let body = resp.into_body().collect().await.unwrap().to_bytes();
2874        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
2875        assert_eq!(json["ready"], true);
2876        assert!(
2877            json.get("name").is_none(),
2878            "readyz must not expose server name"
2879        );
2880        assert!(
2881            json.get("version").is_none(),
2882            "readyz must not expose version"
2883        );
2884        assert_eq!(json["db"], "connected");
2885    }
2886
2887    #[tokio::test]
2888    async fn readyz_returns_503_when_not_ready() {
2889        let check: ReadinessCheck =
2890            Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
2891        let resp = readyz(check).await.into_response();
2892        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
2893    }
2894
2895    #[tokio::test]
2896    async fn readyz_returns_503_when_ready_missing() {
2897        let check: ReadinessCheck =
2898            Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
2899        let resp = readyz(check).await.into_response();
2900        // Missing "ready" field defaults to false -> 503
2901        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
2902    }
2903
2904    // -- origin_check_middleware --
2905
2906    /// Build a test router with origin check middleware and a simple handler.
2907    fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
2908        let allowed: Arc<[String]> = Arc::from(origins);
2909        axum::Router::new()
2910            .route("/test", axum::routing::get(|| async { "ok" }))
2911            .layer(axum::middleware::from_fn(move |req, next| {
2912                let a = Arc::clone(&allowed);
2913                origin_check_middleware(a, log_request_headers, req, next)
2914            }))
2915    }
2916
2917    #[tokio::test]
2918    async fn origin_allowed_passes() {
2919        let app = origin_router(vec!["http://localhost:3000".into()], false);
2920        let req = Request::builder()
2921            .uri("/test")
2922            .header(header::ORIGIN, "http://localhost:3000")
2923            .body(Body::empty())
2924            .unwrap();
2925        let resp = app.oneshot(req).await.unwrap();
2926        assert_eq!(resp.status(), StatusCode::OK);
2927    }
2928
2929    #[tokio::test]
2930    async fn origin_rejected_returns_403() {
2931        let app = origin_router(vec!["http://localhost:3000".into()], false);
2932        let req = Request::builder()
2933            .uri("/test")
2934            .header(header::ORIGIN, "http://evil.com")
2935            .body(Body::empty())
2936            .unwrap();
2937        let resp = app.oneshot(req).await.unwrap();
2938        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
2939    }
2940
2941    #[tokio::test]
2942    async fn no_origin_header_passes() {
2943        let app = origin_router(vec!["http://localhost:3000".into()], false);
2944        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2945        let resp = app.oneshot(req).await.unwrap();
2946        assert_eq!(resp.status(), StatusCode::OK);
2947    }
2948
2949    #[tokio::test]
2950    async fn empty_allowlist_rejects_any_origin() {
2951        let app = origin_router(vec![], false);
2952        let req = Request::builder()
2953            .uri("/test")
2954            .header(header::ORIGIN, "http://anything.com")
2955            .body(Body::empty())
2956            .unwrap();
2957        let resp = app.oneshot(req).await.unwrap();
2958        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
2959    }
2960
2961    #[tokio::test]
2962    async fn empty_allowlist_passes_without_origin() {
2963        let app = origin_router(vec![], false);
2964        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
2965        let resp = app.oneshot(req).await.unwrap();
2966        assert_eq!(resp.status(), StatusCode::OK);
2967    }
2968
2969    #[test]
2970    fn format_request_headers_redacts_sensitive_values() {
2971        let mut headers = axum::http::HeaderMap::new();
2972        headers.insert("authorization", "Bearer secret-token".parse().unwrap());
2973        headers.insert("cookie", "sid=abc".parse().unwrap());
2974        headers.insert("x-request-id", "req-123".parse().unwrap());
2975
2976        let out = format_request_headers_for_log(&headers);
2977        assert!(out.contains("authorization: [REDACTED]"));
2978        assert!(out.contains("cookie: [REDACTED]"));
2979        assert!(out.contains("x-request-id: req-123"));
2980        assert!(!out.contains("secret-token"));
2981    }
2982
2983    // -- security_headers_middleware --
2984
2985    fn security_router(is_tls: bool) -> axum::Router {
2986        security_router_with(is_tls, SecurityHeadersConfig::default())
2987    }
2988
2989    fn security_router_with(is_tls: bool, cfg: SecurityHeadersConfig) -> axum::Router {
2990        let cfg = Arc::new(cfg);
2991        axum::Router::new()
2992            .route("/test", axum::routing::get(|| async { "ok" }))
2993            .layer(axum::middleware::from_fn(move |req, next| {
2994                let c = Arc::clone(&cfg);
2995                security_headers_middleware(is_tls, c, req, next)
2996            }))
2997    }
2998
2999    #[tokio::test]
3000    async fn security_headers_set_on_response() {
3001        let app = security_router(false);
3002        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3003        let resp = app.oneshot(req).await.unwrap();
3004        assert_eq!(resp.status(), StatusCode::OK);
3005
3006        let h = resp.headers();
3007        assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
3008        assert_eq!(h.get("x-frame-options").unwrap(), "deny");
3009        assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
3010        assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
3011        assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
3012        assert_eq!(
3013            h.get("cross-origin-resource-policy").unwrap(),
3014            "same-origin"
3015        );
3016        assert_eq!(
3017            h.get("cross-origin-embedder-policy").unwrap(),
3018            "require-corp"
3019        );
3020        assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
3021        assert!(
3022            h.get("permissions-policy")
3023                .unwrap()
3024                .to_str()
3025                .unwrap()
3026                .contains("camera=()"),
3027            "permissions-policy must restrict browser features"
3028        );
3029        assert_eq!(
3030            h.get("content-security-policy").unwrap(),
3031            "default-src 'none'; frame-ancestors 'none'"
3032        );
3033        assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
3034        // No HSTS when TLS is off.
3035        assert!(h.get("strict-transport-security").is_none());
3036    }
3037
3038    #[tokio::test]
3039    async fn hsts_set_when_tls_enabled() {
3040        let app = security_router(true);
3041        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3042        let resp = app.oneshot(req).await.unwrap();
3043
3044        let hsts = resp.headers().get("strict-transport-security").unwrap();
3045        assert!(
3046            hsts.to_str().unwrap().contains("max-age=63072000"),
3047            "HSTS must set 2-year max-age"
3048        );
3049    }
3050
3051    // -- SecurityHeadersConfig validation + override semantics --
3052
3053    /// Build a minimal config with a custom SecurityHeadersConfig and
3054    /// drive it through `check()`. Returns the result so individual
3055    /// tests can assert on success or specific error messages.
3056    fn check_with_security_headers(headers: SecurityHeadersConfig) -> Result<(), McpxError> {
3057        let cfg =
3058            McpServerConfig::new("127.0.0.1:8080", "test", "0.0.0").with_security_headers(headers);
3059        cfg.check()
3060    }
3061
3062    #[test]
3063    fn security_headers_config_default_validates() {
3064        check_with_security_headers(SecurityHeadersConfig::default())
3065            .expect("default SecurityHeadersConfig must validate");
3066    }
3067
3068    #[test]
3069    fn security_headers_config_validate_accepts_empty_string() {
3070        // All twelve fields explicitly set to "" -> omit-everything mode.
3071        let h = SecurityHeadersConfig {
3072            x_content_type_options: Some(String::new()),
3073            x_frame_options: Some(String::new()),
3074            cache_control: Some(String::new()),
3075            referrer_policy: Some(String::new()),
3076            cross_origin_opener_policy: Some(String::new()),
3077            cross_origin_resource_policy: Some(String::new()),
3078            cross_origin_embedder_policy: Some(String::new()),
3079            permissions_policy: Some(String::new()),
3080            x_permitted_cross_domain_policies: Some(String::new()),
3081            content_security_policy: Some(String::new()),
3082            x_dns_prefetch_control: Some(String::new()),
3083            strict_transport_security: Some(String::new()),
3084        };
3085        check_with_security_headers(h).expect("Some(\"\") on every field must validate (omit-all)");
3086    }
3087
3088    #[test]
3089    fn security_headers_config_validate_rejects_bad_value() {
3090        // 0x07 (BEL) is not a valid HTTP header value char.
3091        let h = SecurityHeadersConfig {
3092            referrer_policy: Some("\u{0007}".into()),
3093            ..SecurityHeadersConfig::default()
3094        };
3095        let err = check_with_security_headers(h)
3096            .expect_err("control char in referrer_policy must reject");
3097        let msg = err.to_string();
3098        assert!(
3099            msg.contains("referrer_policy"),
3100            "error must name the offending field, got: {msg}"
3101        );
3102    }
3103
3104    #[test]
3105    fn security_headers_config_validate_rejects_hsts_preload() {
3106        let h = SecurityHeadersConfig {
3107            strict_transport_security: Some("max-age=63072000; includeSubDomains; preload".into()),
3108            ..SecurityHeadersConfig::default()
3109        };
3110        let err = check_with_security_headers(h).expect_err("HSTS with preload must reject");
3111        let msg = err.to_string();
3112        assert!(
3113            msg.contains("strict_transport_security"),
3114            "error must name the field, got: {msg}"
3115        );
3116        assert!(
3117            msg.to_lowercase().contains("preload"),
3118            "error must mention `preload`, got: {msg}"
3119        );
3120    }
3121
3122    #[test]
3123    fn security_headers_config_validate_rejects_hsts_preload_uppercase() {
3124        // Case-insensitive match.
3125        let h = SecurityHeadersConfig {
3126            strict_transport_security: Some("max-age=600; PRELOAD".into()),
3127            ..SecurityHeadersConfig::default()
3128        };
3129        check_with_security_headers(h).expect_err("HSTS preload check must be case-insensitive");
3130    }
3131
3132    #[tokio::test]
3133    async fn security_headers_override_honored() {
3134        // Override X-Frame-Options to SAMEORIGIN.
3135        let h = SecurityHeadersConfig {
3136            x_frame_options: Some("SAMEORIGIN".into()),
3137            ..SecurityHeadersConfig::default()
3138        };
3139        let app = security_router_with(false, h);
3140        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3141        let resp = app.oneshot(req).await.unwrap();
3142        assert_eq!(resp.status(), StatusCode::OK);
3143
3144        let xfo = resp.headers().get("x-frame-options").unwrap();
3145        assert_eq!(xfo, "SAMEORIGIN");
3146    }
3147
3148    #[tokio::test]
3149    async fn security_headers_empty_string_omits() {
3150        // Empty string on referrer-policy -> header absent.
3151        let h = SecurityHeadersConfig {
3152            referrer_policy: Some(String::new()),
3153            ..SecurityHeadersConfig::default()
3154        };
3155        let app = security_router_with(false, h);
3156        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3157        let resp = app.oneshot(req).await.unwrap();
3158        assert_eq!(resp.status(), StatusCode::OK);
3159
3160        assert!(
3161            resp.headers().get("referrer-policy").is_none(),
3162            "Some(\"\") must omit the header"
3163        );
3164        // Other defaults should still be present.
3165        assert_eq!(
3166            resp.headers().get("x-content-type-options").unwrap(),
3167            "nosniff"
3168        );
3169    }
3170
3171    #[tokio::test]
3172    async fn security_headers_hsts_only_when_tls() {
3173        // HSTS override is irrelevant when TLS is off.
3174        let h = SecurityHeadersConfig {
3175            strict_transport_security: Some("max-age=600".into()),
3176            ..SecurityHeadersConfig::default()
3177        };
3178        let app = security_router_with(false, h);
3179        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3180        let resp = app.oneshot(req).await.unwrap();
3181        assert!(
3182            resp.headers().get("strict-transport-security").is_none(),
3183            "HSTS must remain absent on plaintext deployments even with override"
3184        );
3185    }
3186
3187    // -- oauth_token_cache_headers_middleware --
3188
3189    #[cfg(feature = "oauth")]
3190    #[tokio::test]
3191    async fn oauth_token_cache_headers_set_pragma_and_vary() {
3192        let app = axum::Router::new()
3193            .route("/token", axum::routing::post(|| async { "{}" }))
3194            .layer(axum::middleware::from_fn(
3195                oauth_token_cache_headers_middleware,
3196            ));
3197        let req = Request::builder()
3198            .method("POST")
3199            .uri("/token")
3200            .body(Body::from("{}"))
3201            .unwrap();
3202        let resp = app.oneshot(req).await.unwrap();
3203        assert_eq!(resp.status(), StatusCode::OK);
3204
3205        let h = resp.headers();
3206        assert_eq!(
3207            h.get("pragma").unwrap(),
3208            "no-cache",
3209            "RFC 6749 §5.1: token responses must set Pragma: no-cache"
3210        );
3211        let vary_values: Vec<String> = h
3212            .get_all("vary")
3213            .iter()
3214            .filter_map(|v| v.to_str().ok().map(str::to_owned))
3215            .collect();
3216        assert!(
3217            vary_values
3218                .iter()
3219                .any(|v| v.eq_ignore_ascii_case("Authorization")),
3220            "RFC 6750 §5.4: Vary must include Authorization, got {vary_values:?}"
3221        );
3222    }
3223
3224    #[cfg(feature = "oauth")]
3225    #[tokio::test]
3226    async fn oauth_token_cache_headers_preserve_existing_vary() {
3227        // Simulates a handler/layer that already set `Vary: Accept-Encoding`
3228        // (e.g. compression). Our middleware must APPEND, not REPLACE.
3229        let app = axum::Router::new()
3230            .route(
3231                "/token",
3232                axum::routing::post(|| async {
3233                    axum::response::Response::builder()
3234                        .header("vary", "Accept-Encoding")
3235                        .body(axum::body::Body::from("{}"))
3236                        .unwrap()
3237                }),
3238            )
3239            .layer(axum::middleware::from_fn(
3240                oauth_token_cache_headers_middleware,
3241            ));
3242        let req = Request::builder()
3243            .method("POST")
3244            .uri("/token")
3245            .body(Body::empty())
3246            .unwrap();
3247        let resp = app.oneshot(req).await.unwrap();
3248
3249        let vary: Vec<String> = resp
3250            .headers()
3251            .get_all("vary")
3252            .iter()
3253            .filter_map(|v| v.to_str().ok().map(str::to_owned))
3254            .collect();
3255        assert!(
3256            vary.iter().any(|v| v.contains("Accept-Encoding")),
3257            "must preserve pre-existing Vary value, got {vary:?}"
3258        );
3259        assert!(
3260            vary.iter().any(|v| v.contains("Authorization")),
3261            "must append Authorization to Vary, got {vary:?}"
3262        );
3263    }
3264
3265    // -- version endpoint --
3266
3267    #[test]
3268    fn version_payload_contains_expected_fields() {
3269        let v = version_payload("my-server", "1.2.3");
3270        assert_eq!(v["name"], "my-server");
3271        assert_eq!(v["version"], "1.2.3");
3272        assert!(v["build_git_sha"].is_string());
3273        assert!(v["build_timestamp"].is_string());
3274        assert!(v["rust_version"].is_string());
3275        assert!(v["mcpx_version"].is_string());
3276    }
3277
3278    // -- concurrency limit layer --
3279
3280    #[tokio::test]
3281    async fn concurrency_limit_layer_composes_and_serves() {
3282        // We only assert the layer stack compiles and a single request
3283        // below the cap still succeeds. True back-pressure behaviour
3284        // requires a live HTTP server and is covered by integration tests.
3285        let app = axum::Router::new()
3286            .route("/ok", axum::routing::get(|| async { "ok" }))
3287            .layer(
3288                tower::ServiceBuilder::new()
3289                    .layer(axum::error_handling::HandleErrorLayer::new(
3290                        |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
3291                    ))
3292                    .layer(tower::load_shed::LoadShedLayer::new())
3293                    .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
3294            );
3295        let resp = app
3296            .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
3297            .await
3298            .unwrap();
3299        assert_eq!(resp.status(), StatusCode::OK);
3300    }
3301
3302    // -- compression layer --
3303
3304    #[tokio::test]
3305    async fn compression_layer_gzip_encodes_response() {
3306        use tower_http::compression::Predicate as _;
3307
3308        let big_body = "a".repeat(4096);
3309        let app = axum::Router::new()
3310            .route(
3311                "/big",
3312                axum::routing::get(move || {
3313                    let body = big_body.clone();
3314                    async move { body }
3315                }),
3316            )
3317            .layer(
3318                tower_http::compression::CompressionLayer::new()
3319                    .gzip(true)
3320                    .br(true)
3321                    .compress_when(
3322                        tower_http::compression::DefaultPredicate::new()
3323                            .and(tower_http::compression::predicate::SizeAbove::new(1024)),
3324                    ),
3325            );
3326
3327        let req = Request::builder()
3328            .uri("/big")
3329            .header(header::ACCEPT_ENCODING, "gzip")
3330            .body(Body::empty())
3331            .unwrap();
3332        let resp = app.oneshot(req).await.unwrap();
3333        assert_eq!(resp.status(), StatusCode::OK);
3334        assert_eq!(
3335            resp.headers().get(header::CONTENT_ENCODING).unwrap(),
3336            "gzip"
3337        );
3338    }
3339
3340    // -- TlsListener handshake timeout --
3341
3342    #[tokio::test]
3343    async fn tls_handshake_timeout_reaps_idle_connections() {
3344        use tokio::io::AsyncReadExt as _;
3345
3346        let _ = rustls::crypto::ring::default_provider().install_default();
3347
3348        // Self-signed cert material on disk (TlsListener::new takes paths).
3349        let key = rcgen::KeyPair::generate().expect("generate key");
3350        let cert = rcgen::CertificateParams::new(vec!["localhost".to_owned()])
3351            .expect("cert params")
3352            .self_signed(&key)
3353            .expect("self-signed cert");
3354        let dir = std::env::temp_dir().join(format!(
3355            "rmcp-server-kit-hs-timeout-{}",
3356            std::time::SystemTime::now()
3357                .duration_since(std::time::UNIX_EPOCH)
3358                .expect("clock after epoch")
3359                .as_nanos()
3360        ));
3361        tokio::fs::create_dir_all(&dir).await.expect("temp dir");
3362        let cert_path = dir.join("server.crt");
3363        let key_path = dir.join("server.key");
3364        tokio::fs::write(&cert_path, cert.pem())
3365            .await
3366            .expect("write cert");
3367        tokio::fs::write(&key_path, key.serialize_pem())
3368            .await
3369            .expect("write key");
3370
3371        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
3372        let tls = TlsListener::new(
3373            listener,
3374            &cert_path,
3375            &key_path,
3376            None,
3377            None,
3378            Duration::from_millis(200),
3379        )
3380        .expect("tls listener");
3381        let addr = axum::serve::Listener::local_addr(&tls).expect("local addr");
3382
3383        // Connect and send NOTHING: the handshake worker must time out
3384        // after 200ms and drop the stream, which the client observes as
3385        // EOF or a reset well within the 2s deadline.
3386        let mut idle = tokio::net::TcpStream::connect(addr).await.expect("connect");
3387        let mut buf = [0_u8; 16];
3388        let read = tokio::time::timeout(Duration::from_secs(2), idle.read(&mut buf))
3389            .await
3390            .expect("server must reap the idle handshake within its timeout");
3391        match read {
3392            Ok(0) | Err(_) => {} // EOF or reset: connection was dropped.
3393            Ok(n) => panic!("unexpected {n} bytes from server during reaped handshake"),
3394        }
3395
3396        drop(tls);
3397    }
3398}