Skip to main content

rmcp_server_kit/
transport.rs

1use std::{
2    future::Future,
3    net::SocketAddr,
4    path::{Path, PathBuf},
5    pin::Pin,
6    sync::Arc,
7    time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{
12    body::Body,
13    extract::{ConnectInfo, Request},
14    middleware::Next,
15    response::IntoResponse,
16};
17use rmcp::{
18    ServerHandler,
19    transport::streamable_http_server::{
20        StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
21    },
22};
23use rustls::RootCertStore;
24use tokio::{
25    net::TcpListener,
26    sync::{Semaphore, mpsc},
27};
28use tokio_util::sync::CancellationToken;
29
30use crate::{
31    auth::{
32        AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
33        build_rate_limiter, extract_mtls_identity,
34    },
35    error::McpxError,
36    mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
37    rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
38};
39
40/// Map an internal `anyhow::Error` chain into a public [`McpxError::Startup`]
41/// at the public API boundary, flattening the chain via the alternate
42/// formatter so callers see the full causal path.
43#[allow(
44    clippy::needless_pass_by_value,
45    reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
46)]
47fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
48    McpxError::Startup(format!("{e:#}"))
49}
50
51/// Map a `std::io::Error` produced during server startup into a public
52/// [`McpxError::Startup`]. We deliberately do not use the [`McpxError::Io`]
53/// `From` impl here because startup-phase IO errors (bind, listener) are
54/// semantically distinct from request-time IO errors and should surface
55/// the originating operation in the message.
56#[allow(
57    clippy::needless_pass_by_value,
58    reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
59)]
60fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
61    McpxError::Startup(format!("{op}: {e}"))
62}
63
/// Async readiness check callback for the `/readyz` endpoint.
///
/// Returns a JSON object with at least a `"ready"` boolean.
/// When `ready` is false, the endpoint returns HTTP 503.
///
/// The alias is a shared (`Arc`), `Send + Sync` closure that takes no
/// arguments and returns a pinned, boxed, `Send` future resolving to a
/// [`serde_json::Value`]. Install one with
/// [`McpServerConfig::with_readiness_check`].
pub type ReadinessCheck =
    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
70
/// The direct socket peer address of the HTTP/TLS connection a request
/// arrived on.
///
/// [`serve`] attaches it to every request as a request extension, on the
/// plain listener and the TLS listener alike, so any axum handler can
/// extract it. That includes routes mounted through
/// [`McpServerConfig::with_extra_router`]: those bypass auth/RBAC and so
/// frequently need the peer address for their own protection (per-IP
/// rate limiting, for instance).
///
/// On the TLS listener the address is additionally mirrored into
/// [`axum::extract::ConnectInfo`]`<SocketAddr>`, which lets third-party
/// middleware written against the stock axum extension (such as per-IP
/// rate-limit key extractors) run unmodified under TLS.
///
/// # Semantics
///
/// - **Direct peer only.** The value is the remote address of the
///   socket. When the server sits behind an L4/L7 proxy or load balancer
///   this is the proxy's address; the framework does **no**
///   `X-Forwarded-For` / `Forwarded` interpretation.
/// - **Present on both HTTP and TLS** transports ([`serve`]).
/// - **Missing under [`serve_stdio`]**, because a stdio session bypasses
///   the HTTP stack entirely and has no network peer.
/// - The separate Prometheus metrics listener (feature `metrics`) uses a
///   different router and does not carry this extension.
///
/// # Privacy
///
/// `PeerAddr` exposes raw peer network metadata. The framework
/// deliberately never logs it by itself; logging or persisting peer
/// addresses is a matter of application policy.
///
/// # Example
///
/// ```no_run
/// use axum::{Router, routing::get};
/// use rmcp_server_kit::transport::{McpServerConfig, PeerAddr};
///
/// async fn whoami(peer: PeerAddr) -> String {
///     peer.addr.ip().to_string()
/// }
///
/// let _config = McpServerConfig::new("127.0.0.1:8443", "my-server", "1.0.0")
///     .with_extra_router(Router::new().route("/whoami", get(whoami)));
/// ```
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct PeerAddr {
    /// Remote address of the socket that carries this connection.
    pub addr: SocketAddr,
}

impl PeerAddr {
    /// Wrap a socket address in a [`PeerAddr`].
    ///
    /// Crate-internal constructor: downstream code obtains `PeerAddr`
    /// from request extensions and never builds one itself.
    #[must_use]
    pub(crate) const fn new(addr: SocketAddr) -> Self {
        Self { addr }
    }
}
130
131/// Extract [`PeerAddr`] from request extensions.
132///
133/// # Rejection
134///
135/// Responds `500 Internal Server Error` when the extension is missing.
136/// A missing `PeerAddr` means the handler is not running under [`serve`]
137/// (e.g. the router was mounted on a hand-rolled listener) — a wiring
138/// bug, not a client error.
139impl<S: Send + Sync> axum::extract::FromRequestParts<S> for PeerAddr {
140    type Rejection = (axum::http::StatusCode, &'static str);
141
142    async fn from_request_parts(
143        parts: &mut axum::http::request::Parts,
144        _state: &S,
145    ) -> Result<Self, Self::Rejection> {
146        parts.extensions.get::<Self>().copied().ok_or((
147            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
148            "peer address unavailable: not running under rmcp-server-kit serve()",
149        ))
150    }
151}
152
/// Per-header overrides for the OWASP security headers emitted by the
/// global response middleware.
///
/// Each field follows a three-state semantic:
///
/// | Value         | Behaviour                                                |
/// |---------------|----------------------------------------------------------|
/// | `None`        | Use the built-in default (current behaviour).            |
/// | `Some("")`    | **Omit** the header entirely from responses.             |
/// | `Some(value)` | Emit `header: value`. Validated at config-load time.     |
///
/// All non-empty values are validated via
/// [`axum::http::HeaderValue::from_str`] inside
/// [`McpServerConfig::validate`]; invalid values fail fast before the
/// server starts accepting traffic.
///
/// `Strict-Transport-Security` has an additional rule: the substring
/// `preload` (case-insensitive) is rejected. Operators who want to
/// commit to the HSTS preload list must do so via a future explicit
/// builder method, not by smuggling it through this knob.
///
/// # Example
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot
/// use a struct literal; start from [`Default`] and assign the fields
/// to override.
///
/// ```no_run
/// use rmcp_server_kit::transport::{McpServerConfig, SecurityHeadersConfig};
///
/// let mut headers = SecurityHeadersConfig::default();
/// // Override one header...
/// headers.x_frame_options = Some("sameorigin".to_owned());
/// // ...and omit another entirely.
/// headers.permissions_policy = Some(String::new());
///
/// let _config = McpServerConfig::new("127.0.0.1:8080", "my-server", "1.0.0")
///     .with_security_headers(headers);
/// ```
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct SecurityHeadersConfig {
    /// Override for `X-Content-Type-Options`. Default: `nosniff`.
    pub x_content_type_options: Option<String>,
    /// Override for `X-Frame-Options`. Default: `deny`.
    pub x_frame_options: Option<String>,
    /// Override for `Cache-Control`. Default: `no-store, max-age=0`.
    pub cache_control: Option<String>,
    /// Override for `Referrer-Policy`. Default: `no-referrer`.
    pub referrer_policy: Option<String>,
    /// Override for `Cross-Origin-Opener-Policy`. Default: `same-origin`.
    pub cross_origin_opener_policy: Option<String>,
    /// Override for `Cross-Origin-Resource-Policy`. Default: `same-origin`.
    pub cross_origin_resource_policy: Option<String>,
    /// Override for `Cross-Origin-Embedder-Policy`. Default: `require-corp`.
    pub cross_origin_embedder_policy: Option<String>,
    /// Override for `Permissions-Policy`. Default:
    /// `accelerometer=(), camera=(), geolocation=(), microphone=()`.
    pub permissions_policy: Option<String>,
    /// Override for `X-Permitted-Cross-Domain-Policies`. Default: `none`.
    pub x_permitted_cross_domain_policies: Option<String>,
    /// Override for `Content-Security-Policy`. Default:
    /// `default-src 'none'; frame-ancestors 'none'`.
    pub content_security_policy: Option<String>,
    /// Override for `X-DNS-Prefetch-Control`. Default: `off`.
    pub x_dns_prefetch_control: Option<String>,
    /// Override for `Strict-Transport-Security`. Default (TLS only):
    /// `max-age=63072000; includeSubDomains`. Only emitted when TLS is
    /// active; the override is ignored on plaintext deployments. The
    /// substring `preload` (any case) is rejected by the validator.
    pub strict_transport_security: Option<String>,
}
206
/// Configuration for the MCP server.
///
/// Build one with [`McpServerConfig::new`] and customize it through the
/// chainable `with_*` / `enable_*` builder methods. The struct is
/// `#[non_exhaustive]`, so code outside this crate cannot use a struct
/// literal.
///
/// Every field is still `pub` but carries `#[deprecated]`: each
/// deprecation note names the builder method to use instead and
/// announces that direct field access will become `pub(crate)` in a
/// future major release.
#[allow(
    missing_debug_implementations,
    reason = "contains callback/trait objects that don't impl Debug"
)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "server configuration naturally has many boolean feature flags"
)]
#[non_exhaustive]
pub struct McpServerConfig {
    /// Socket address the MCP HTTP server binds to.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in a future major release"
    )]
    pub bind_addr: String,
    /// Server name advertised via MCP `initialize`.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub name: String,
    /// Server version advertised via MCP `initialize`.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub version: String,
    /// Path to the TLS certificate (PEM). Required for TLS/mTLS.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_cert_path: Option<PathBuf>,
    /// Path to the TLS private key (PEM). Required for TLS/mTLS.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_key_path: Option<PathBuf>,
    /// Optional authentication config. When `Some` and `enabled`, auth
    /// is enforced on `/mcp`. `/healthz` is always open.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in a future major release"
    )]
    pub auth: Option<AuthConfig>,
    /// Optional RBAC policy. When present and enabled, tool calls are
    /// checked against the policy after authentication.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in a future major release"
    )]
    pub rbac: Option<Arc<RbacPolicy>>,
    /// Allowed Origin values for DNS rebinding protection (MCP spec MUST).
    /// When empty and `public_url` is set, the origin is auto-derived from
    /// the public URL. When both are empty, only requests with no Origin
    /// header are accepted.
    /// Example entries: `"http://localhost:3000"`, `"https://myapp.example.com"`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in a future major release"
    )]
    pub allowed_origins: Vec<String>,
    /// Maximum tool invocations per source IP per minute.
    /// When set, enforced on every `tools/call` request.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in a future major release"
    )]
    pub tool_rate_limit: Option<u32>,
    /// Optional readiness probe for `/readyz`.
    /// When `None`, `/readyz` mirrors `/healthz` (always OK).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in a future major release"
    )]
    pub readiness_check: Option<ReadinessCheck>,
    /// Maximum request body size in bytes. Default: 1 MiB.
    /// Protects against oversized payloads causing OOM.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_request_body: usize,
    /// Request processing timeout. Default: 120s.
    /// Requests exceeding this duration receive 408 Request Timeout.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub request_timeout: Duration,
    /// Graceful shutdown timeout. Default: 30s.
    /// After the shutdown signal, in-flight requests have this long to finish.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub shutdown_timeout: Duration,
    /// Idle timeout for MCP sessions. Sessions with no activity for this
    /// duration are closed automatically. Default: 20 minutes.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub session_idle_timeout: Duration,
    /// Interval for SSE keep-alive pings. Prevents proxies and load
    /// balancers from killing idle connections. Default: 15 seconds.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in a future major release"
    )]
    pub sse_keep_alive: Duration,
    /// Callback invoked once the server is built, delivering a
    /// [`ReloadHandle`] for hot-reloading auth keys and RBAC policy
    /// at runtime (e.g. on SIGHUP). Only useful when auth/RBAC is enabled.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in a future major release"
    )]
    pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    /// Additional application-specific routes merged into the top-level
    /// router.  These routes **bypass** the MCP auth and RBAC middleware,
    /// so the application is responsible for its own auth on them.
    /// Handlers can extract [`PeerAddr`] (or
    /// [`axum::extract::ConnectInfo`]`<SocketAddr>` for third-party
    /// middleware compatibility) regardless of whether TLS is enabled.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in a future major release"
    )]
    pub extra_router: Option<axum::Router>,
    /// Externally reachable base URL (e.g. `https://mcp.example.com`).
    /// When set, OAuth metadata endpoints advertise this URL instead of
    /// the listen address. Required when binding `0.0.0.0` behind a
    /// reverse proxy or inside a container.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in a future major release"
    )]
    pub public_url: Option<String>,
    /// Log inbound HTTP request headers at DEBUG level.
    /// Sensitive values remain redacted.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in a future major release"
    )]
    pub log_request_headers: bool,
    /// Enable gzip/br response compression on MCP responses.
    /// Defaults to `false` to preserve existing behaviour.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_enabled: bool,
    /// Minimum response body size (in bytes) before compression kicks in.
    /// Only used when `compression_enabled` is true. Default: 1024.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_min_size: u16,
    /// Global cap on in-flight HTTP requests across the whole server.
    /// When `Some`, requests over the cap receive 503 Service Unavailable
    /// via `tower::load_shed`. Default: `None` (unlimited).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_concurrent_requests: Option<usize>,
    /// Enable `/admin/*` diagnostic endpoints. Requires `auth` to be
    /// configured and `enabled`. Default: `false`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_enabled: bool,
    /// RBAC role required to access admin endpoints. Default: `"admin"`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_role: String,
    /// Enable Prometheus metrics endpoint on a separate listener.
    /// Requires the `metrics` crate feature.
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_enabled: bool,
    /// Bind address for the Prometheus metrics listener. Default: `127.0.0.1:9090`.
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_bind: String,
    /// Per-header overrides for the OWASP security headers emitted by
    /// the global response middleware. See [`SecurityHeadersConfig`]
    /// for the three-state semantic and validation rules.
    #[deprecated(
        since = "1.5.0",
        note = "use McpServerConfig::with_security_headers(); direct field access will become pub(crate) in a future major release"
    )]
    pub security_headers: SecurityHeadersConfig,
    /// Per-handshake deadline on the TLS accept path. Idle or slow-loris
    /// connections are dropped once it elapses. Default: 10s.
    ///
    /// Startup-only: bound at listener construction, NOT hot-reloadable
    /// via [`ReloadHandle`]. Ignored unless TLS is configured.
    #[deprecated(
        since = "1.9.0",
        note = "use McpServerConfig::with_tls_handshake_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_handshake_timeout: Duration,
    /// Cap on concurrently in-flight TLS handshakes. At saturation the
    /// acceptor stops pulling new connections from the kernel backlog
    /// (backpressure) instead of accepting and dropping. Default: 256.
    ///
    /// Startup-only: bound at listener construction, NOT hot-reloadable
    /// via [`ReloadHandle`]. Ignored unless TLS is configured.
    #[deprecated(
        since = "1.9.0",
        note = "use McpServerConfig::with_max_concurrent_tls_handshakes(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_concurrent_tls_handshakes: usize,
}
436
/// Marker that wraps a value proven to satisfy its validation
/// contract.
///
/// The only way to obtain `Validated<McpServerConfig>` is by calling
/// [`McpServerConfig::validate`], which is the contract enforced at
/// the type level by [`serve`] and [`serve_with_listener`]. The
/// inner field is private, so downstream code cannot bypass
/// validation by hand-constructing the wrapper.
///
/// Use [`Validated::as_inner`] for read-only borrowing. To mutate,
/// recover the raw value with [`Validated::into_inner`] and
/// re-validate.
///
/// # Example
///
/// ```no_run
/// use rmcp_server_kit::transport::{McpServerConfig, Validated, serve};
/// use rmcp::handler::server::ServerHandler;
/// use rmcp::model::{ServerCapabilities, ServerInfo};
///
/// #[derive(Clone)]
/// struct H;
/// impl ServerHandler for H {
///     fn get_info(&self) -> ServerInfo {
///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
///     }
/// }
///
/// # async fn example() -> rmcp_server_kit::Result<()> {
/// let config: Validated<McpServerConfig> =
///     McpServerConfig::new("127.0.0.1:8080", "my-server", "0.1.0").validate()?;
/// serve(config, || H).await
/// # }
/// ```
///
/// Forgetting `.validate()?` is a compile error:
///
/// ```compile_fail
/// use rmcp_server_kit::transport::{McpServerConfig, serve};
/// use rmcp::handler::server::ServerHandler;
/// use rmcp::model::{ServerCapabilities, ServerInfo};
///
/// #[derive(Clone)]
/// struct H;
/// impl ServerHandler for H {
///     fn get_info(&self) -> ServerInfo {
///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
///     }
/// }
///
/// # async fn example() -> rmcp_server_kit::Result<()> {
/// let config = McpServerConfig::new("127.0.0.1:8080", "my-server", "0.1.0");
/// // Missing `.validate()?` -> mismatched types: expected
/// // `Validated<McpServerConfig>`, found `McpServerConfig`.
/// serve(config, || H).await
/// # }
/// ```
// No `#[allow(missing_debug_implementations)]` is needed here: the manual
// `Debug` impl below already satisfies that lint.
pub struct Validated<T>(T);

/// Deliberately opaque `Debug`: always renders as `Validated { .. }`.
///
/// Implemented by hand (rather than derived) for two reasons: the
/// wrapped `T` may not implement `Debug`, and the inner contents must
/// not leak into logs.
impl<T> std::fmt::Debug for Validated<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Validated").finish_non_exhaustive()
    }
}

impl<T> Validated<T> {
    /// Borrow the inner value.
    #[must_use]
    pub fn as_inner(&self) -> &T {
        &self.0
    }

    /// Recover the raw value, discarding the validation proof.
    ///
    /// Re-validate before re-using the value with [`serve`] or
    /// [`serve_with_listener`].
    #[must_use]
    pub fn into_inner(self) -> T {
        self.0
    }
}
522
523#[allow(
524    deprecated,
525    reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
526)]
527impl McpServerConfig {
528    /// Create a new server configuration with the given bind address,
529    /// server name, and version. All other fields use safe defaults.
530    ///
531    /// Use the chainable `with_*` / `enable_*` builder methods to
532    /// customize. Call [`McpServerConfig::validate`] to obtain a
533    /// [`Validated<McpServerConfig>`] proof token, which is required by
534    /// [`serve`] and [`serve_with_listener`].
535    #[must_use]
536    pub fn new(
537        bind_addr: impl Into<String>,
538        name: impl Into<String>,
539        version: impl Into<String>,
540    ) -> Self {
541        Self {
542            bind_addr: bind_addr.into(),
543            name: name.into(),
544            version: version.into(),
545            tls_cert_path: None,
546            tls_key_path: None,
547            auth: None,
548            rbac: None,
549            allowed_origins: Vec::new(),
550            tool_rate_limit: None,
551            readiness_check: None,
552            max_request_body: 1024 * 1024,
553            request_timeout: Duration::from_mins(2),
554            shutdown_timeout: Duration::from_secs(30),
555            session_idle_timeout: Duration::from_mins(20),
556            sse_keep_alive: Duration::from_secs(15),
557            on_reload_ready: None,
558            extra_router: None,
559            public_url: None,
560            log_request_headers: false,
561            compression_enabled: false,
562            compression_min_size: 1024,
563            max_concurrent_requests: None,
564            admin_enabled: false,
565            admin_role: "admin".to_owned(),
566            #[cfg(feature = "metrics")]
567            metrics_enabled: false,
568            #[cfg(feature = "metrics")]
569            metrics_bind: "127.0.0.1:9090".into(),
570            security_headers: SecurityHeadersConfig::default(),
571            tls_handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
572            max_concurrent_tls_handshakes: DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES,
573        }
574    }
575
576    // ---------------------------------------------------------------
577    // Builder methods (fluent, consume + return self).
578    //
579    // Each method is `#[must_use]` because dropping the returned
580    // `McpServerConfig` discards the configuration change.
581    // ---------------------------------------------------------------
582
583    /// Attach an authentication configuration. Required for
584    /// [`enable_admin`](Self::enable_admin) and any non-public deployment.
585    #[must_use]
586    pub fn with_auth(mut self, auth: AuthConfig) -> Self {
587        self.auth = Some(auth);
588        self
589    }
590
591    /// Override one or more of the OWASP security headers emitted on
592    /// every response. See [`SecurityHeadersConfig`] for the three-state
593    /// semantic (`None` = default, `Some("")` = omit, `Some(v)` =
594    /// override). Values are validated by [`Self::validate`].
595    #[must_use]
596    pub fn with_security_headers(mut self, headers: SecurityHeadersConfig) -> Self {
597        self.security_headers = headers;
598        self
599    }
600
601    /// Override the bind address (e.g. `127.0.0.1:8080`). Useful when the
602    /// final port is only known after pre-binding an ephemeral listener
603    /// (tests, dynamic-port deployments).
604    #[must_use]
605    pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
606        self.bind_addr = addr.into();
607        self
608    }
609
610    /// Attach an RBAC policy. Tool calls are checked against the policy
611    /// after authentication.
612    #[must_use]
613    pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
614        self.rbac = Some(rbac);
615        self
616    }
617
618    /// Configure TLS by providing the certificate and private key paths
619    /// (PEM). Both must be readable at startup. Without this call, the
620    /// server runs plain HTTP.
621    #[must_use]
622    pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
623        self.tls_cert_path = Some(cert_path.into());
624        self.tls_key_path = Some(key_path.into());
625        self
626    }
627
628    /// Set the externally reachable base URL (e.g. `https://mcp.example.com`).
629    /// Required when binding `0.0.0.0` behind a reverse proxy or inside
630    /// a container so OAuth metadata and auto-derived origins resolve correctly.
631    #[must_use]
632    pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
633        self.public_url = Some(url.into());
634        self
635    }
636
637    /// Replace the allowed Origin allow-list (DNS-rebinding protection).
638    /// When empty and [`with_public_url`](Self::with_public_url) is set,
639    /// the origin is auto-derived.
640    #[must_use]
641    pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
642    where
643        I: IntoIterator<Item = S>,
644        S: Into<String>,
645    {
646        self.allowed_origins = origins.into_iter().map(Into::into).collect();
647        self
648    }
649
650    /// Merge an additional axum router at the top level. Routes added
651    /// here **bypass** rmcp-server-kit auth and RBAC; the application is responsible
652    /// for its own protection.
653    ///
654    /// To support that protection (e.g. per-IP rate limiting on
655    /// unauthenticated endpoints), every request served by [`serve`]
656    /// carries the client peer address regardless of whether TLS is
657    /// enabled: extract the framework-owned [`PeerAddr`] in your
658    /// handlers, or rely on [`axum::extract::ConnectInfo`]`<SocketAddr>`
659    /// for stock third-party middleware (e.g. per-IP rate-limit key
660    /// extractors). Neither extension exists under [`serve_stdio`],
661    /// which has no network peer.
662    #[must_use]
663    pub fn with_extra_router(mut self, router: axum::Router) -> Self {
664        self.extra_router = Some(router);
665        self
666    }
667
668    /// Install an async readiness probe for `/readyz`. Without this call,
669    /// `/readyz` mirrors `/healthz` (always 200 OK).
670    #[must_use]
671    pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
672        self.readiness_check = Some(check);
673        self
674    }
675
676    /// Override the maximum request body (bytes). Must be `> 0`.
677    /// Default: 1 MiB.
678    #[must_use]
679    pub fn with_max_request_body(mut self, bytes: usize) -> Self {
680        self.max_request_body = bytes;
681        self
682    }
683
684    /// Override the per-request processing timeout. Default: 2 minutes.
685    #[must_use]
686    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
687        self.request_timeout = timeout;
688        self
689    }
690
691    /// Override the graceful shutdown grace period. Default: 30 seconds.
692    #[must_use]
693    pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
694        self.shutdown_timeout = timeout;
695        self
696    }
697
698    /// Override the MCP session idle timeout. Default: 20 minutes.
699    #[must_use]
700    pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
701        self.session_idle_timeout = timeout;
702        self
703    }
704
705    /// Override the SSE keep-alive interval. Default: 15 seconds.
706    #[must_use]
707    pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
708        self.sse_keep_alive = interval;
709        self
710    }
711
712    /// Cap the global number of in-flight HTTP requests via
713    /// `tower::load_shed`. Excess requests receive 503 Service Unavailable.
714    /// Default: unlimited.
715    #[must_use]
716    pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
717        self.max_concurrent_requests = Some(limit);
718        self
719    }
720
721    /// Override the per-handshake deadline on the TLS accept path.
722    /// Idle or slow-loris connections are dropped once it elapses.
723    /// Default: 10s. Must be greater than zero.
724    ///
725    /// Startup-only: the value is bound at listener construction and is
726    /// NOT hot-reloadable via [`ReloadHandle`]. Has no effect unless TLS
727    /// is configured via [`Self::with_tls`].
728    #[must_use]
729    pub fn with_tls_handshake_timeout(mut self, timeout: Duration) -> Self {
730        self.tls_handshake_timeout = timeout;
731        self
732    }
733
734    /// Override the cap on concurrently in-flight TLS handshakes. At
735    /// saturation the acceptor stops pulling new connections from the
736    /// kernel backlog (backpressure) instead of accepting and dropping.
737    /// Default: 256. Must be greater than zero.
738    ///
739    /// Startup-only: the value is bound at listener construction and is
740    /// NOT hot-reloadable via [`ReloadHandle`]. Has no effect unless TLS
741    /// is configured via [`Self::with_tls`].
742    #[must_use]
743    pub fn with_max_concurrent_tls_handshakes(mut self, limit: usize) -> Self {
744        self.max_concurrent_tls_handshakes = limit;
745        self
746    }
747
748    /// Cap tool invocations per source IP per minute. Enforced on every
749    /// `tools/call` request.
750    #[must_use]
751    pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
752        self.tool_rate_limit = Some(per_minute);
753        self
754    }
755
756    /// Register a callback that receives the [`ReloadHandle`] after the
757    /// server is built. Use it to wire SIGHUP-style hot reloads of API
758    /// keys and RBAC policy.
759    #[must_use]
760    pub fn with_reload_callback<F>(mut self, callback: F) -> Self
761    where
762        F: FnOnce(ReloadHandle) + Send + 'static,
763    {
764        self.on_reload_ready = Some(Box::new(callback));
765        self
766    }
767
768    /// Enable gzip/brotli response compression on MCP responses.
769    /// `min_size` is the smallest body size (bytes) eligible for
770    /// compression. Default min size: 1024.
771    #[must_use]
772    pub fn enable_compression(mut self, min_size: u16) -> Self {
773        self.compression_enabled = true;
774        self.compression_min_size = min_size;
775        self
776    }
777
778    /// Enable `/admin/*` diagnostic endpoints. Requires
779    /// [`with_auth`](Self::with_auth) to be set and enabled; otherwise
780    /// [`validate`](Self::validate) returns an error. `role` is the RBAC
781    /// role gate (default: `"admin"`).
782    #[must_use]
783    pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
784        self.admin_enabled = true;
785        self.admin_role = role.into();
786        self
787    }
788
789    /// Log inbound HTTP request headers at DEBUG level. Sensitive
790    /// values remain redacted by the logging layer.
791    #[must_use]
792    pub fn enable_request_header_logging(mut self) -> Self {
793        self.log_request_headers = true;
794        self
795    }
796
797    /// Enable the Prometheus metrics listener on `bind` (e.g.
798    /// `127.0.0.1:9090`). Requires the `metrics` crate feature.
799    #[cfg(feature = "metrics")]
800    #[must_use]
801    pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
802        self.metrics_enabled = true;
803        self.metrics_bind = bind.into();
804        self
805    }
806
807    /// Validate the configuration and consume `self`, returning a
808    /// [`Validated<McpServerConfig>`] proof token required by [`serve`]
809    /// and [`serve_with_listener`]. This is the only way to construct
810    /// `Validated<McpServerConfig>`, so the type system guarantees
811    /// validation has run before the server starts.
812    ///
813    /// Checks:
814    ///
815    /// 1. `admin_enabled` requires `auth` to be configured and enabled.
816    /// 2. `tls_cert_path` and `tls_key_path` must both be set or both
817    ///    be unset.
818    /// 3. `bind_addr` must parse as a [`SocketAddr`].
819    /// 4. `public_url`, when set, must start with `http://` or `https://`.
820    /// 5. Each entry in `allowed_origins` must start with `http://` or
821    ///    `https://`.
822    /// 6. `max_request_body` must be greater than zero.
823    /// 7. When the `oauth` feature is enabled and an [`OAuthConfig`] is
824    ///    present, all OAuth URL fields (`jwks_uri`, `proxy.authorize_url`,
825    ///    `proxy.token_url`, `proxy.introspection_url`,
826    ///    `proxy.revocation_url`, `token_exchange.token_url`) must parse
827    ///    and use the `https` scheme. Set
828    ///    [`OAuthConfig::allow_http_oauth_urls`] to permit `http://`
829    ///    targets (strongly discouraged in production - see the field-level
830    ///    docs for the threat model).
831    ///
832    /// [`OAuthConfig`]: crate::oauth::OAuthConfig
833    /// [`OAuthConfig::allow_http_oauth_urls`]: crate::oauth::OAuthConfig::allow_http_oauth_urls
834    ///
835    /// # Errors
836    ///
837    /// Returns [`McpxError::Config`] with a human-readable message on
838    /// the first validation failure.
839    pub fn validate(self) -> Result<Validated<Self>, McpxError> {
840        self.check()?;
841        Ok(Validated(self))
842    }
843
844    /// Run the validation checks without consuming `self`. Used by
845    /// internal call sites (e.g. tests) that need to inspect a config
846    /// without taking ownership.
847    fn check(&self) -> Result<(), McpxError> {
848        // 1. admin <-> auth dependency. Mirrors the runtime check in
849        //    `build_app_router`: admin endpoints require an auth state,
850        //    which is built only when `auth` is `Some` *and* `enabled`.
851        if self.admin_enabled {
852            let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
853            if !auth_enabled {
854                return Err(McpxError::Config(
855                    "admin_enabled=true requires auth to be configured and enabled".into(),
856                ));
857            }
858        }
859
860        // 2. TLS cert / key must be paired
861        match (&self.tls_cert_path, &self.tls_key_path) {
862            (Some(_), None) => {
863                return Err(McpxError::Config(
864                    "tls_cert_path is set but tls_key_path is missing".into(),
865                ));
866            }
867            (None, Some(_)) => {
868                return Err(McpxError::Config(
869                    "tls_key_path is set but tls_cert_path is missing".into(),
870                ));
871            }
872            _ => {}
873        }
874
875        // 3. bind_addr parses
876        if self.bind_addr.parse::<SocketAddr>().is_err() {
877            return Err(McpxError::Config(format!(
878                "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
879                self.bind_addr
880            )));
881        }
882
883        // 4. public_url scheme
884        if let Some(ref url) = self.public_url
885            && !(url.starts_with("http://") || url.starts_with("https://"))
886        {
887            return Err(McpxError::Config(format!(
888                "public_url {url:?} must start with http:// or https://"
889            )));
890        }
891
892        // 5. allowed_origins scheme
893        for origin in &self.allowed_origins {
894            if !(origin.starts_with("http://") || origin.starts_with("https://")) {
895                return Err(McpxError::Config(format!(
896                    "allowed_origins entry {origin:?} must start with http:// or https://"
897                )));
898            }
899        }
900
901        // 6. max_request_body > 0
902        if self.max_request_body == 0 {
903            return Err(McpxError::Config(
904                "max_request_body must be greater than zero".into(),
905            ));
906        }
907
908        // 7. OAuth URL fields enforce HTTPS (unless `allow_http_oauth_urls`)
909        #[cfg(feature = "oauth")]
910        if let Some(auth_cfg) = &self.auth
911            && let Some(oauth_cfg) = &auth_cfg.oauth
912        {
913            oauth_cfg.validate()?;
914        }
915
916        // 8. Security-header overrides parse as valid HTTP header values,
917        //    and HSTS does not smuggle in a `preload` directive.
918        validate_security_headers(&self.security_headers)?;
919
920        // 9. max_concurrent_requests must be > 0 when set. Zero would
921        //    deadlock the global concurrency limiter and reject every
922        //    request. Mirrors the TOML-side check in `src/config.rs`.
923        if let Some(0) = self.max_concurrent_requests {
924            return Err(McpxError::Config(
925                "max_concurrent_requests must be greater than zero when set".into(),
926            ));
927        }
928
929        // 10. Auth rate-limit `max_tracked_keys` must be > 0. A zero cap
930        //     would force `BoundedKeyedLimiter` to evict on every insert
931        //     and effectively disable rate limiting.
932        if let Some(auth_cfg) = &self.auth
933            && let Some(rl) = &auth_cfg.rate_limit
934            && rl.max_tracked_keys == 0
935        {
936            return Err(McpxError::Config(
937                "auth.rate_limit.max_tracked_keys must be greater than zero".into(),
938            ));
939        }
940
941        // 11. tls_handshake_timeout must be > 0. A zero deadline would
942        //     reap every handshake before it could complete, rejecting
943        //     all TLS connections. Mirrors the TOML-side check in
944        //     `src/config.rs`.
945        if self.tls_handshake_timeout == Duration::ZERO {
946            return Err(McpxError::Config(
947                "tls_handshake_timeout must be greater than zero".into(),
948            ));
949        }
950
951        // 12. max_concurrent_tls_handshakes must be > 0. A zero-permit
952        //     semaphore would never admit a handshake, deadlocking the
953        //     TLS accept path. Mirrors the TOML-side check in
954        //     `src/config.rs`.
955        if self.max_concurrent_tls_handshakes == 0 {
956            return Err(McpxError::Config(
957                "max_concurrent_tls_handshakes must be greater than zero".into(),
958            ));
959        }
960
961        Ok(())
962    }
963}
964
/// Handle for hot-reloading server configuration without restart.
///
/// Obtained through the callback registered with
/// [`McpServerConfig::with_reload_callback`], which receives the handle
/// once the server has been built.
/// All swap operations are lock-free and wait-free -- in-flight requests
/// finish with the old values while new requests see the update immediately.
///
/// Every component is optional: a reload method whose backing state is
/// absent is a no-op, except [`ReloadHandle::refresh_crls`], which
/// reports an error.
#[allow(
    missing_debug_implementations,
    reason = "contains Arc<AuthState> with non-Debug fields"
)]
pub struct ReloadHandle {
    /// Shared auth state targeted by [`ReloadHandle::reload_auth_keys`];
    /// `None` makes that call a no-op.
    auth: Option<Arc<AuthState>>,
    /// Swappable RBAC policy targeted by [`ReloadHandle::reload_rbac`];
    /// `None` makes that call a no-op.
    rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
    /// CRL set refreshed by [`ReloadHandle::refresh_crls`]; `None` makes
    /// that call return a configuration error.
    crl_set: Option<Arc<CrlSet>>,
}
979
980impl ReloadHandle {
981    /// Atomically replace the API key list used by the auth middleware.
982    pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
983        if let Some(ref auth) = self.auth {
984            auth.reload_keys(keys);
985        }
986    }
987
988    /// Atomically replace the RBAC policy used by the RBAC middleware.
989    pub fn reload_rbac(&self, policy: RbacPolicy) {
990        if let Some(ref rbac) = self.rbac {
991            rbac.store(Arc::new(policy));
992            tracing::info!("RBAC policy reloaded");
993        }
994    }
995
996    /// Force an immediate refresh of all cached mTLS CRLs.
997    ///
998    /// # Errors
999    ///
1000    /// Returns an error if CRL refresh is unavailable or verifier rebuild fails.
1001    pub async fn refresh_crls(&self) -> Result<(), McpxError> {
1002        let Some(ref crl_set) = self.crl_set else {
1003            return Err(McpxError::Config(
1004                "CRL refresh requested but mTLS CRL support is not configured".into(),
1005            ));
1006        };
1007
1008        crl_set.force_refresh().await
1009    }
1010}
1011
1012/// Generic MCP HTTP server.
1013///
1014/// Wraps an axum server with `/healthz` and `/mcp` endpoints.
1015/// When `tls_cert_path` and `tls_key_path` are both set, the server binds
1016/// with TLS (rustls). Optionally supports mTLS client certificate auth.
1017///
1018/// # Errors
1019///
1020/// Returns an error if the TCP listener cannot bind, TLS config is invalid,
1021/// or the server fails.
1022// NOTE: cognitive complexity reduced from 111/25 to 83/25 by
1023// extracting `run_server` (serve-loop tail) and `install_oauth_proxy_routes`.
1024// Remaining flow is a linear router builder: middleware layering, feature-
1025// gated auth/RBAC wiring, and PRM/metrics installation. Further extraction
1026// would require threading many `&mut Router` helpers and hurt readability
1027// of the layer order (which is security-relevant and must stay visible).
1028#[allow(
1029    clippy::too_many_lines,
1030    clippy::cognitive_complexity,
1031    reason = "middleware layer order is security-critical and must remain visible at one glance; extracting `&mut Router` helpers would obscure the auth/RBAC/origin/rate-limit ordering"
1032)]
1033/// Internal bundle of values produced by [`build_app_router`] and
1034/// consumed by [`serve`] / [`serve_with_listener`] when driving the
1035/// HTTP listener.
1036struct AppRunParams {
1037    /// TLS cert/key paths when TLS is configured.
1038    tls_paths: Option<(PathBuf, PathBuf)>,
1039    /// Per-handshake deadline on the TLS accept path.
1040    tls_handshake_timeout: Duration,
1041    /// Cap on concurrently in-flight TLS handshakes.
1042    max_concurrent_tls_handshakes: usize,
1043    /// mTLS configuration when mutual-TLS auth is enabled.
1044    mtls_config: Option<MtlsConfig>,
1045    /// Graceful shutdown drain window.
1046    shutdown_timeout: Duration,
1047    /// Shared auth state used by hot-reload callbacks.
1048    auth_state: Option<Arc<AuthState>>,
1049    /// Hot-reloadable RBAC state used by reload callbacks.
1050    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
1051    /// Optional callback that receives the final [`ReloadHandle`].
1052    on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
1053    /// Server-internal cancellation token. Cancelled by [`run_server`]
1054    /// once the shutdown trigger fires (so rmcp's child token also
1055    /// fires, terminating in-flight MCP sessions).
1056    ct: CancellationToken,
1057    /// `"http"` or `"https"` -- used only for boot-time logging.
1058    scheme: &'static str,
1059    /// Server name -- used only for boot-time logging.
1060    name: String,
1061}
1062
1063/// Build the full application axum [`axum::Router`] (MCP route +
1064/// middleware stack + admin + OAuth + health endpoints + security
1065/// headers + CORS + compression + concurrency limit + origin check)
1066/// and the [`AppRunParams`] needed to drive it.
1067///
1068/// This is the shared core of [`serve`] and [`serve_with_listener`].
1069/// It performs *no* network I/O: callers are responsible for binding
1070/// (or accepting a pre-bound) [`TcpListener`] and invoking
1071/// [`run_server`].
1072#[allow(
1073    clippy::cognitive_complexity,
1074    reason = "router assembly is intrinsically sequential; splitting harms readability"
1075)]
1076#[allow(
1077    deprecated,
1078    reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
1079)]
1080fn build_app_router<H, F>(
1081    mut config: McpServerConfig,
1082    handler_factory: F,
1083) -> anyhow::Result<(axum::Router, AppRunParams)>
1084where
1085    H: ServerHandler + 'static,
1086    F: Fn() -> H + Send + Sync + Clone + 'static,
1087{
1088    let ct = CancellationToken::new();
1089
1090    let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
1091    tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");
1092
1093    let mcp_service = StreamableHttpService::new(
1094        move || Ok(handler_factory()),
1095        {
1096            let mut mgr = LocalSessionManager::default();
1097            mgr.session_config.keep_alive = Some(config.session_idle_timeout);
1098            mgr.into()
1099        },
1100        StreamableHttpServerConfig::default()
1101            .with_allowed_hosts(allowed_hosts)
1102            .with_sse_keep_alive(Some(config.sse_keep_alive))
1103            .with_cancellation_token(ct.child_token()),
1104    );
1105
1106    // Build the MCP route, optionally wrapped with auth and RBAC middleware.
1107    let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);
1108
1109    // Build auth state eagerly when auth is configured so we can wire both
1110    // the auth middleware *and* the optional admin router against the same
1111    // state. The middleware itself is installed further down in layer order.
1112    let auth_state: Option<Arc<AuthState>> = match config.auth {
1113        Some(ref auth_config) if auth_config.enabled => {
1114            let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
1115            let pre_auth_limiter = auth_config
1116                .rate_limit
1117                .as_ref()
1118                .map(crate::auth::build_pre_auth_limiter);
1119
1120            #[cfg(feature = "oauth")]
1121            let jwks_cache = auth_config
1122                .oauth
1123                .as_ref()
1124                .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
1125                .transpose()
1126                .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;
1127
1128            Some(Arc::new(AuthState {
1129                api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
1130                rate_limiter,
1131                pre_auth_limiter,
1132                #[cfg(feature = "oauth")]
1133                jwks_cache,
1134                seen_identities: crate::auth::SeenIdentitySet::new(),
1135                counters: crate::auth::AuthCounters::default(),
1136            }))
1137        }
1138        _ => None,
1139    };
1140
1141    // Build the RBAC policy swap early so the admin router and the later
1142    // RBAC middleware layer share the same hot-reloadable state.
1143    let rbac_swap = Arc::new(ArcSwap::new(
1144        config
1145            .rbac
1146            .clone()
1147            .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
1148    ));
1149
1150    // Optional /admin/* diagnostic routes. Merged BEFORE the
1151    // body-limit/timeout/RBAC/origin/auth layers so all of them apply.
1152    if config.admin_enabled {
1153        let Some(ref auth_state_ref) = auth_state else {
1154            return Err(anyhow::anyhow!(
1155                "admin_enabled=true requires auth to be configured and enabled"
1156            ));
1157        };
1158        let admin_state = crate::admin::AdminState {
1159            started_at: std::time::Instant::now(),
1160            name: config.name.clone(),
1161            version: config.version.clone(),
1162            auth: Some(Arc::clone(auth_state_ref)),
1163            rbac: Arc::clone(&rbac_swap),
1164        };
1165        let admin_cfg = crate::admin::AdminConfig {
1166            role: config.admin_role.clone(),
1167        };
1168        mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
1169        tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
1170    }
1171
1172    // ----- Middleware order (CRITICAL: read carefully) ------------------
1173    //
1174    // axum/tower applies layers **bottom-up** at runtime: the LAST layer
1175    // added is the OUTERMOST (runs first on a request). To achieve a
1176    // request-time flow of:
1177    //
1178    //   body-limit -> timeout -> auth -> rbac -> handler
1179    //
1180    // we add layers in the REVERSE order:
1181    //
1182    //   1. RBAC               (innermost, runs last before handler)
1183    //   2. auth               (parses identity, sets extension for RBAC)
1184    //   3. timeout            (bounds total request time)
1185    //   4. body-limit         (outermost on /mcp; caps payload before
1186    //                          anything else reads/buffers it)
1187    //
1188    // Origin validation is installed on the OUTER router (after the
1189    // /mcp router is merged in), so it also protects /healthz, /readyz,
1190    // /version, and any OAuth proxy endpoints.
1191    //
1192    // Rationale:
1193    // - Body-limit must be outermost on /mcp so RBAC (which reads the
1194    //   JSON-RPC body) cannot be DoS'd by a 100MB payload.
1195    // - Auth must run before RBAC because RBAC consumes
1196    //   `req.extensions().get::<AuthIdentity>()` to enforce per-role
1197    //   policy.
1198    // - Origin runs before auth so we reject cross-origin requests
1199    //   without spending Argon2 cycles on unauthenticated callers.
1200
1201    // [1] RBAC + tool rate-limit layer (innermost; closest to handler).
1202    // Always installed: even when RBAC is disabled, tool rate limiting may
1203    // be active (MCP spec: servers MUST rate limit tool invocations).
1204    {
1205        let tool_limiter: Option<Arc<ToolRateLimiter>> =
1206            config.tool_rate_limit.map(build_tool_rate_limiter);
1207
1208        if rbac_swap.load().is_enabled() {
1209            tracing::info!("RBAC enforcement enabled on /mcp");
1210        }
1211        if let Some(limit) = config.tool_rate_limit {
1212            tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
1213        }
1214
1215        let rbac_for_mw = Arc::clone(&rbac_swap);
1216        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1217            let p = rbac_for_mw.load_full();
1218            let tl = tool_limiter.clone();
1219            rbac_middleware(p, tl, req, next)
1220        }));
1221    }
1222
1223    // [2] Auth layer (runs before RBAC so AuthIdentity is in extensions).
1224    if let Some(ref auth_config) = config.auth
1225        && auth_config.enabled
1226    {
1227        let Some(ref state) = auth_state else {
1228            return Err(anyhow::anyhow!("auth state missing despite enabled config"));
1229        };
1230
1231        let methods: Vec<&str> = [
1232            auth_config.mtls.is_some().then_some("mTLS"),
1233            (!auth_config.api_keys.is_empty()).then_some("bearer"),
1234            #[cfg(feature = "oauth")]
1235            auth_config.oauth.is_some().then_some("oauth-jwt"),
1236        ]
1237        .into_iter()
1238        .flatten()
1239        .collect();
1240
1241        tracing::info!(
1242            methods = %methods.join(", "),
1243            api_keys = auth_config.api_keys.len(),
1244            "auth enabled on /mcp"
1245        );
1246
1247        let state_for_mw = Arc::clone(state);
1248        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1249            let s = Arc::clone(&state_for_mw);
1250            auth_middleware(s, req, next)
1251        }));
1252    }
1253
1254    // [3] Request timeout (returns 408 on expiry). Bounds total request
1255    // duration including auth + handler.
1256    mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
1257        axum::http::StatusCode::REQUEST_TIMEOUT,
1258        config.request_timeout,
1259    ));
1260
1261    // [4] Request body size limit (OUTERMOST on /mcp). Prevents OOM /
1262    // DoS from oversized payloads BEFORE any inner layer (auth, RBAC)
1263    // attempts to buffer or parse the body.
1264    mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
1265        config.max_request_body,
1266    ));
1267
1268    // Compute the effective allowed-origins list for the outer
1269    // origin-check layer (installed on the merged router below). When
1270    // `allowed_origins` is empty but `public_url` is set, auto-derive
1271    // the origin from the public URL so MCP clients (e.g. Claude Code)
1272    // that send `Origin: <server-url>` are accepted without explicit
1273    // config.
1274    let mut effective_origins = config.allowed_origins.clone();
1275    if effective_origins.is_empty()
1276        && let Some(ref url) = config.public_url
1277    {
1278        // Origin = scheme + "://" + host (+ ":" + port if non-default).
1279        // Strip any path/query from the public URL. Offsets come from
1280        // `find`, so they are char-boundary-aligned; `get(..)` keeps that
1281        // machine-checked (a violation degrades to an empty slice).
1282        if let Some(scheme_end) = url.find("://") {
1283            let scheme_with_sep = url.get(..scheme_end + 3).unwrap_or_default();
1284            let after_scheme = url.get(scheme_end + 3..).unwrap_or_default();
1285            let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
1286            let host = after_scheme.get(..host_end).unwrap_or_default();
1287            let origin = format!("{scheme_with_sep}{host}");
1288            tracing::info!(
1289                %origin,
1290                "auto-derived allowed origin from public_url"
1291            );
1292            effective_origins.push(origin);
1293        }
1294    }
1295    let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
1296    let cors_origins = Arc::clone(&allowed_origins);
1297    let log_request_headers = config.log_request_headers;
1298
1299    let readyz_route = if let Some(check) = config.readiness_check.take() {
1300        axum::routing::get(move || readyz(Arc::clone(&check)))
1301    } else {
1302        axum::routing::get(healthz)
1303    };
1304
1305    #[allow(unused_mut)] // mut needed when oauth feature adds PRM route
1306    let mut router = axum::Router::new()
1307        .route("/healthz", axum::routing::get(healthz))
1308        .route("/readyz", readyz_route)
1309        .route(
1310            "/version",
1311            axum::routing::get({
1312                // Pre-serialize the version payload once at router-build
1313                // time. The handler then serves a cheap `Arc::clone` of the
1314                // immutable bytes per request, avoiding `serde_json::Value`
1315                // allocation + serialization on every `/version` hit.
1316                let payload_bytes: Arc<[u8]> =
1317                    serialize_version_payload(&config.name, &config.version);
1318                move || {
1319                    let p = Arc::clone(&payload_bytes);
1320                    async move {
1321                        (
1322                            [(axum::http::header::CONTENT_TYPE, "application/json")],
1323                            p.to_vec(),
1324                        )
1325                    }
1326                }
1327            }),
1328        )
1329        .merge(mcp_router);
1330
1331    // Merge application-specific routes (bypass MCP auth/RBAC middleware).
1332    if let Some(extra) = config.extra_router.take() {
1333        router = router.merge(extra);
1334    }
1335
1336    // RFC 9728: Protected Resource Metadata endpoint.
1337    // When OAuth is configured, serve full metadata with authorization_servers.
1338    // Otherwise, serve a minimal document with just the resource URL and no
1339    // authorization_servers -- this tells MCP clients (e.g. Claude Code SDK)
1340    // that the server exists but does NOT require OAuth authentication,
1341    // preventing them from gating the connection behind a broken auth flow.
1342    let server_url = if let Some(ref url) = config.public_url {
1343        url.trim_end_matches('/').to_owned()
1344    } else {
1345        let prm_scheme = if config.tls_cert_path.is_some() {
1346            "https"
1347        } else {
1348            "http"
1349        };
1350        format!("{prm_scheme}://{}", config.bind_addr)
1351    };
1352    let resource_url = format!("{server_url}/mcp");
1353
1354    #[cfg(feature = "oauth")]
1355    let prm_metadata = if let Some(ref auth_config) = config.auth
1356        && let Some(ref oauth_config) = auth_config.oauth
1357    {
1358        crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
1359    } else {
1360        serde_json::json!({ "resource": resource_url })
1361    };
1362    #[cfg(not(feature = "oauth"))]
1363    let prm_metadata = serde_json::json!({ "resource": resource_url });
1364
1365    router = router.route(
1366        "/.well-known/oauth-protected-resource",
1367        axum::routing::get(move || {
1368            let m = prm_metadata.clone();
1369            async move { axum::Json(m) }
1370        }),
1371    );
1372
1373    // OAuth 2.1 proxy endpoints: when an OAuth proxy is configured, expose
1374    // /authorize, /token, /register, and authorization server metadata so
1375    // MCP clients can perform Authorization Code + PKCE against the upstream
1376    // IdP (e.g. Keycloak) transparently.
1377    #[cfg(feature = "oauth")]
1378    if let Some(ref auth_config) = config.auth
1379        && let Some(ref oauth_config) = auth_config.oauth
1380        && oauth_config.proxy.is_some()
1381    {
1382        router =
1383            install_oauth_proxy_routes(router, &server_url, oauth_config, auth_state.as_ref())?;
1384    }
1385
1386    // OWASP security response headers (applied to all responses).
1387    // HSTS is conditional on TLS being configured.
1388    let is_tls = config.tls_cert_path.is_some();
1389    let security_headers_cfg = Arc::new(config.security_headers.clone());
1390    router = router.layer(axum::middleware::from_fn(move |req, next| {
1391        let cfg = Arc::clone(&security_headers_cfg);
1392        security_headers_middleware(is_tls, cfg, req, next)
1393    }));
1394
1395    // CORS preflight layer (required for browser-based MCP clients).
1396    // Uses the same effective origins as the origin check middleware
1397    // (including auto-derived origin from public_url).
1398    if !cors_origins.is_empty() {
1399        let cors = tower_http::cors::CorsLayer::new()
1400            .allow_origin(
1401                cors_origins
1402                    .iter()
1403                    .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
1404                    .collect::<Vec<_>>(),
1405            )
1406            .allow_methods([
1407                axum::http::Method::GET,
1408                axum::http::Method::POST,
1409                axum::http::Method::OPTIONS,
1410            ])
1411            .allow_headers([
1412                axum::http::header::CONTENT_TYPE,
1413                axum::http::header::AUTHORIZATION,
1414            ]);
1415        router = router.layer(cors);
1416    }
1417
1418    // Optional response compression (gzip + brotli). Skips small bodies
1419    // to avoid overhead. Applied after CORS so preflight responses remain
1420    // uncompressed.
1421    if config.compression_enabled {
1422        use tower_http::compression::Predicate as _;
1423        let predicate = tower_http::compression::DefaultPredicate::new().and(
1424            tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
1425        );
1426        router = router.layer(
1427            tower_http::compression::CompressionLayer::new()
1428                .gzip(true)
1429                .br(true)
1430                .compress_when(predicate),
1431        );
1432        tracing::info!(
1433            min_size = config.compression_min_size,
1434            "response compression enabled (gzip, br)"
1435        );
1436    }
1437
1438    // Optional global concurrency cap. `load_shed` converts the
1439    // `ConcurrencyLimit` back-pressure error into 503 instead of hanging.
1440    if let Some(max) = config.max_concurrent_requests {
1441        let overload_handler = tower::ServiceBuilder::new()
1442            .layer(axum::error_handling::HandleErrorLayer::new(
1443                |_err: tower::BoxError| async {
1444                    (
1445                        axum::http::StatusCode::SERVICE_UNAVAILABLE,
1446                        axum::Json(serde_json::json!({
1447                            "error": "overloaded",
1448                            "error_description": "server is at capacity, retry later"
1449                        })),
1450                    )
1451                },
1452            ))
1453            .layer(tower::load_shed::LoadShedLayer::new())
1454            .layer(tower::limit::ConcurrencyLimitLayer::new(max));
1455        router = router.layer(overload_handler);
1456        tracing::info!(max, "global concurrency limit enabled");
1457    }
1458
1459    // JSON fallback for unmatched routes. Without this, axum returns
1460    // an empty-body 404 that breaks MCP clients (e.g. Claude Code SDK)
1461    // when they probe OAuth endpoints like /authorize or /token.
1462    router = router.fallback(|| async {
1463        (
1464            axum::http::StatusCode::NOT_FOUND,
1465            axum::Json(serde_json::json!({
1466                "error": "not_found",
1467                "error_description": "The requested endpoint does not exist"
1468            })),
1469        )
1470    });
1471
1472    // Prometheus metrics: recording middleware + separate listener.
1473    #[cfg(feature = "metrics")]
1474    if config.metrics_enabled {
1475        let metrics = Arc::new(
1476            crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
1477        );
1478        let m = Arc::clone(&metrics);
1479        router = router.layer(axum::middleware::from_fn(
1480            move |req: Request<Body>, next: Next| {
1481                let m = Arc::clone(&m);
1482                metrics_middleware(m, req, next)
1483            },
1484        ));
1485        let metrics_bind = config.metrics_bind.clone();
1486        let metrics_shutdown = ct.clone();
1487        tokio::spawn(async move {
1488            if let Err(e) =
1489                crate::metrics::serve_metrics(metrics_bind, metrics, metrics_shutdown).await
1490            {
1491                tracing::error!("metrics listener failed: {e}");
1492            }
1493        });
1494    }
1495
1496    // Peer-address normalization. Mirrors the TLS branch's peer address
1497    // into `ConnectInfo<SocketAddr>` and exposes the framework-owned
1498    // `PeerAddr` extension on both listener branches, so ALL routes on
1499    // the merged router (`/mcp`, `/healthz`, OAuth proxy endpoints,
1500    // admin endpoints, extra_router, ...) and all inner middleware see a
1501    // uniform peer-address contract regardless of TLS. Installed just
1502    // inside the origin check, which stays outermost by design.
1503    router = router.layer(axum::middleware::from_fn(normalize_peer_addr_middleware));
1504
1505    // Origin validation layer (MCP spec: servers MUST validate the
1506    // Origin header to prevent DNS rebinding attacks). Installed as the
1507    // OUTERMOST layer on the OUTER router so it protects ALL routes
1508    // (`/mcp`, `/healthz`, `/readyz`, `/version`, OAuth proxy endpoints,
1509    // admin endpoints, extra_router, etc.) and runs BEFORE auth so we
1510    // reject cross-origin attackers without spending Argon2 cycles.
1511    //
1512    // Origin-less requests (e.g. server-to-server probes, curl, native
1513    // MCP clients) are permitted; only requests with an Origin header
1514    // that does not match `effective_origins` are rejected.
1515    router = router.layer(axum::middleware::from_fn(move |req, next| {
1516        let origins = Arc::clone(&allowed_origins);
1517        origin_check_middleware(origins, log_request_headers, req, next)
1518    }));
1519
1520    let scheme = if config.tls_cert_path.is_some() {
1521        "https"
1522    } else {
1523        "http"
1524    };
1525
1526    let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
1527        (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
1528        _ => None,
1529    };
1530    let tls_handshake_timeout = config.tls_handshake_timeout;
1531    let max_concurrent_tls_handshakes = config.max_concurrent_tls_handshakes;
1532    let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();
1533
1534    Ok((
1535        router,
1536        AppRunParams {
1537            tls_paths,
1538            tls_handshake_timeout,
1539            max_concurrent_tls_handshakes,
1540            mtls_config,
1541            shutdown_timeout: config.shutdown_timeout,
1542            auth_state,
1543            rbac_swap,
1544            on_reload_ready: config.on_reload_ready.take(),
1545            ct,
1546            scheme,
1547            name: config.name.clone(),
1548        },
1549    ))
1550}
1551
1552/// Run the MCP HTTP server, binding to `config.bind_addr` and serving
1553/// until an OS shutdown signal (Ctrl-C / SIGTERM) is received.
1554///
1555/// This is the standard entry point for production deployments. For
1556/// deterministic shutdown control (e.g. integration tests), see
1557/// [`serve_with_listener`].
1558///
1559/// The configuration must be validated first via
1560/// [`McpServerConfig::validate`], which returns a [`Validated`] proof
1561/// token. This typestate guarantees, at compile time, that the server
1562/// never starts with an invalid configuration.
1563///
1564/// # Errors
1565///
1566/// Returns [`McpxError::Startup`] if binding to `config.bind_addr`
1567/// fails, or if the underlying axum server returns an error.
1568pub async fn serve<H, F>(
1569    config: Validated<McpServerConfig>,
1570    handler_factory: F,
1571) -> Result<(), McpxError>
1572where
1573    H: ServerHandler + 'static,
1574    F: Fn() -> H + Send + Sync + Clone + 'static,
1575{
1576    let config = config.into_inner();
1577    #[allow(
1578        deprecated,
1579        reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1580    )]
1581    let bind_addr = config.bind_addr.clone();
1582    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1583
1584    let listener = TcpListener::bind(&bind_addr)
1585        .await
1586        .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1587    log_listening(&params.name, params.scheme, &bind_addr);
1588
1589    run_server(
1590        router,
1591        listener,
1592        params.tls_paths,
1593        params.tls_handshake_timeout,
1594        params.max_concurrent_tls_handshakes,
1595        params.mtls_config,
1596        params.shutdown_timeout,
1597        params.auth_state,
1598        params.rbac_swap,
1599        params.on_reload_ready,
1600        params.ct,
1601    )
1602    .await
1603    .map_err(anyhow_to_startup)
1604}
1605
1606/// Run the MCP HTTP server on a pre-bound [`TcpListener`], with optional
1607/// readiness signalling and external shutdown control.
1608///
1609/// This variant is intended for **deterministic integration tests** and
1610/// for embedders that need to bind the listening socket themselves
1611/// (e.g. systemd socket activation). Compared to [`serve`]:
1612///
1613/// * The caller passes a `TcpListener` that is already bound. This
1614///   eliminates the bind race in tests that previously required
1615///   poll-the-`/healthz`-loop start-up detection.
1616/// * `ready_tx`, when `Some`, receives the socket's
1617///   [`SocketAddr`] *after* the router is built and immediately before
1618///   the server starts accepting connections. Tests can `await` the
1619///   matching `oneshot::Receiver` to know exactly when it is safe to
1620///   issue requests.
1621/// * `shutdown`, when `Some`, gives the caller a
1622///   [`CancellationToken`] that triggers the same graceful-shutdown
1623///   path as a real OS signal. This avoids cross-platform issues with
1624///   sending real `SIGTERM` from tests on Windows.
1625///
1626/// All three optional parameters degrade gracefully: if `ready_tx` is
1627/// `None`, no signal is sent; if `shutdown` is `None`, the server only
1628/// stops on an OS signal (just like [`serve`]).
1629///
1630/// # Errors
1631///
1632/// Returns [`McpxError::Startup`] if router construction fails, if reading
1633/// the listener's `local_addr()` fails, or if the underlying axum
1634/// server returns an error.
1635pub async fn serve_with_listener<H, F>(
1636    listener: TcpListener,
1637    config: Validated<McpServerConfig>,
1638    handler_factory: F,
1639    ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
1640    shutdown: Option<CancellationToken>,
1641) -> Result<(), McpxError>
1642where
1643    H: ServerHandler + 'static,
1644    F: Fn() -> H + Send + Sync + Clone + 'static,
1645{
1646    let config = config.into_inner();
1647    let local_addr = listener
1648        .local_addr()
1649        .map_err(|e| io_to_startup("listener.local_addr", e))?;
1650    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1651
1652    log_listening(&params.name, params.scheme, &local_addr.to_string());
1653
1654    // Forward external shutdown into the server-internal cancellation
1655    // token so `run_server`'s shutdown trigger picks it up alongside
1656    // any real OS signal.
1657    if let Some(external) = shutdown {
1658        let internal = params.ct.clone();
1659        tokio::spawn(async move {
1660            external.cancelled().await;
1661            internal.cancel();
1662        });
1663    }
1664
1665    // Signal readiness *after* the router is fully built and external
1666    // shutdown is wired, but *before* run_server takes ownership of
1667    // the listener. The receiver can immediately issue requests.
1668    if let Some(tx) = ready_tx {
1669        // Receiver may have been dropped (test gave up). That's fine.
1670        let _ = tx.send(local_addr);
1671    }
1672
1673    run_server(
1674        router,
1675        listener,
1676        params.tls_paths,
1677        params.tls_handshake_timeout,
1678        params.max_concurrent_tls_handshakes,
1679        params.mtls_config,
1680        params.shutdown_timeout,
1681        params.auth_state,
1682        params.rbac_swap,
1683        params.on_reload_ready,
1684        params.ct,
1685    )
1686    .await
1687    .map_err(anyhow_to_startup)
1688}
1689
1690/// Emit the standard "listening on …" log lines used by both
1691/// [`serve`] and [`serve_with_listener`].
1692#[allow(
1693    clippy::cognitive_complexity,
1694    reason = "tracing::info! macro expansions inflate the score; logic is trivial"
1695)]
1696fn log_listening(name: &str, scheme: &str, addr: &str) {
1697    tracing::info!("{name} listening on {addr}");
1698    tracing::info!("  MCP endpoint: {scheme}://{addr}/mcp");
1699    tracing::info!("  Health check: {scheme}://{addr}/healthz");
1700    tracing::info!("  Readiness:   {scheme}://{addr}/readyz");
1701}
1702
1703/// Drive the chosen axum server variant (TLS or plain) with a graceful
1704/// shutdown window. Consumes the router and listener.
1705///
1706/// # Shutdown semantics
1707///
1708/// A single shutdown trigger (the FIRST of: OS signal via
1709/// `shutdown_signal()`, or external cancellation of `ct`) starts BOTH:
1710///
1711/// 1. axum's `.with_graceful_shutdown(...)` future, which stops
1712///    accepting new connections and waits for in-flight requests to
1713///    drain;
1714/// 2. a `tokio::time::sleep(shutdown_timeout)` race that forces exit if
1715///    drainage exceeds `shutdown_timeout`.
1716///
1717/// Previously this function awaited `shutdown_signal()` independently
1718/// in BOTH branches of a `tokio::select!`. Because `shutdown_signal`
1719/// resolves once per future and consumes one signal, the force-exit
1720/// timer was tied to a SECOND signal (a second SIGTERM the operator
1721/// would never send). Under a single SIGTERM the graceful drain could
1722/// hang indefinitely. The current implementation derives both branches
1723/// from a single shared trigger so the timeout race is anchored to the
1724/// FIRST (and only) signal.
1725#[allow(
1726    clippy::too_many_arguments,
1727    clippy::cognitive_complexity,
1728    reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
1729)]
1730async fn run_server(
1731    router: axum::Router,
1732    listener: TcpListener,
1733    tls_paths: Option<(PathBuf, PathBuf)>,
1734    tls_handshake_timeout: Duration,
1735    max_concurrent_tls_handshakes: usize,
1736    mtls_config: Option<MtlsConfig>,
1737    shutdown_timeout: Duration,
1738    auth_state: Option<Arc<AuthState>>,
1739    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
1740    mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
1741    ct: CancellationToken,
1742) -> anyhow::Result<()> {
1743    // `shutdown_trigger` fires when the FIRST source resolves: either
1744    // an OS signal (Ctrl-C / SIGTERM) or external cancellation of `ct`
1745    // (which the test harness uses for deterministic shutdown).
1746    let shutdown_trigger = CancellationToken::new();
1747    {
1748        let trigger = shutdown_trigger.clone();
1749        let parent = ct.clone();
1750        tokio::spawn(async move {
1751            tokio::select! {
1752                () = shutdown_signal() => {}
1753                () = parent.cancelled() => {}
1754            }
1755            trigger.cancel();
1756        });
1757    }
1758
1759    let graceful = {
1760        let trigger = shutdown_trigger.clone();
1761        let ct = ct.clone();
1762        async move {
1763            trigger.cancelled().await;
1764            tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
1765            ct.cancel();
1766        }
1767    };
1768
1769    let force_exit_timer = {
1770        let trigger = shutdown_trigger.clone();
1771        async move {
1772            trigger.cancelled().await;
1773            tokio::time::sleep(shutdown_timeout).await;
1774        }
1775    };
1776
1777    if let Some((cert_path, key_path)) = tls_paths {
1778        let crl_set = if let Some(mtls) = mtls_config.as_ref()
1779            && mtls.crl_enabled
1780        {
1781            let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
1782            let (crl_set, discover_rx) =
1783                mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
1784                    .await
1785                    .map_err(|error| anyhow::anyhow!(error.to_string()))?;
1786            tokio::spawn(mtls_revocation::run_crl_refresher(
1787                Arc::clone(&crl_set),
1788                discover_rx,
1789                ct.clone(),
1790            ));
1791            Some(crl_set)
1792        } else {
1793            None
1794        };
1795
1796        if let Some(cb) = on_reload_ready.take() {
1797            cb(ReloadHandle {
1798                auth: auth_state.clone(),
1799                rbac: Some(Arc::clone(&rbac_swap)),
1800                crl_set: crl_set.clone(),
1801            });
1802        }
1803
1804        let tls_listener = TlsListener::new(
1805            listener,
1806            &cert_path,
1807            &key_path,
1808            mtls_config.as_ref(),
1809            crl_set,
1810            tls_handshake_timeout,
1811            max_concurrent_tls_handshakes,
1812        )?;
1813        let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
1814        tokio::select! {
1815            result = axum::serve(tls_listener, make_svc)
1816                .with_graceful_shutdown(graceful) => { result?; }
1817            () = force_exit_timer => {
1818                tracing::warn!("shutdown timeout exceeded, forcing exit");
1819            }
1820        }
1821    } else {
1822        if let Some(cb) = on_reload_ready.take() {
1823            cb(ReloadHandle {
1824                auth: auth_state,
1825                rbac: Some(rbac_swap),
1826                crl_set: None,
1827            });
1828        }
1829
1830        let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
1831        tokio::select! {
1832            result = axum::serve(listener, make_svc)
1833                .with_graceful_shutdown(graceful) => { result?; }
1834            () = force_exit_timer => {
1835                tracing::warn!("shutdown timeout exceeded, forcing exit");
1836            }
1837        }
1838    }
1839
1840    Ok(())
1841}
1842
/// Mount the OAuth 2.1 proxy endpoints on `router`: the authorization
/// server metadata document, `/authorize`, `/token`, `/register`, and —
/// when the proxy exposes them — the `/introspect` / `/revoke` admin
/// routes.
///
/// When `oauth_config.proxy` is `None` the router is returned unchanged.
///
/// # Errors
///
/// Returns [`McpxError::Startup`] if the shared
/// [`crate::oauth::OauthHttpClient`] cannot be initialized, and
/// propagates any error from building the admin sub-router.
#[cfg(feature = "oauth")]
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let Some(proxy) = oauth_config.proxy.as_ref() else {
        return Ok(router);
    };

    // One HTTP client is shared by every proxy endpoint; clones are
    // refcounted handles onto the same connection pool.
    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;

    let metadata = crate::oauth::authorization_server_metadata(server_url, oauth_config);
    let metadata_route = axum::routing::get(move || {
        let doc = metadata.clone();
        async move { axum::Json(doc) }
    });

    let authorize_cfg = proxy.clone();
    let authorize_route = axum::routing::get(
        move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
            let cfg = authorize_cfg.clone();
            async move { crate::oauth::handle_authorize(&cfg, &query.unwrap_or_default()) }
        },
    );

    let token_cfg = proxy.clone();
    let token_client = http.clone();
    let token_route = axum::routing::post(move |body: String| {
        let cfg = token_cfg.clone();
        let client = token_client.clone();
        async move { crate::oauth::handle_token(&client, &cfg, &body).await }
    })
    .layer(axum::middleware::from_fn(
        oauth_token_cache_headers_middleware,
    ));

    let register_cfg = proxy.clone();
    let register_route =
        axum::routing::post(move |axum::Json(body): axum::Json<serde_json::Value>| {
            let cfg = register_cfg;
            async move { axum::Json(crate::oauth::handle_register(&cfg, &body)) }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        ));

    let router = router
        .route("/.well-known/oauth-authorization-server", metadata_route)
        .route("/authorize", authorize_route)
        .route("/token", token_route)
        .route("/register", register_route);

    // Each admin endpoint is exposed only when the operator opted in
    // AND an upstream URL for it is configured.
    let introspect_exposed = proxy.expose_admin_endpoints && proxy.introspection_url.is_some();
    let revoke_exposed = proxy.expose_admin_endpoints && proxy.revocation_url.is_some();

    let unauthenticated_opt_in = proxy.expose_admin_endpoints
        && !proxy.require_auth_on_admin_endpoints
        && proxy.allow_unauthenticated_admin_endpoints;
    if unauthenticated_opt_in {
        // The M3 escape hatch is active: validate() accepted this
        // because the operator opted in explicitly. Log it loudly at
        // startup so the decision is auditable.
        tracing::warn!(
            "OAuth introspect/revoke endpoints are unauthenticated by explicit \
             allow_unauthenticated_admin_endpoints opt-out; ensure an \
             authenticated reverse proxy fronts these routes"
        );
    }

    let admin_router = if introspect_exposed || revoke_exposed {
        build_oauth_admin_router(proxy, http, auth_state)?
    } else {
        axum::Router::new()
    };
    let router = router.merge(admin_router);

    tracing::info!(
        introspect = introspect_exposed,
        revoke = revoke_exposed,
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
1943
/// Assemble the optional admin sub-router carrying `/introspect` and
/// `/revoke`. A route is mounted only when its upstream URL is
/// configured on `proxy`.
///
/// The routes are wrapped in [`oauth_token_cache_headers_middleware`]
/// so the RFC 6749 §5.1 / RFC 6750 §5.4 cache headers are emitted. When
/// `proxy.require_auth_on_admin_endpoints` is set, the auth middleware
/// is layered on top of that.
///
/// # Errors
///
/// Returns [`McpxError::Startup`] when authentication is required for
/// the admin endpoints but no `auth_state` was supplied.
#[cfg(feature = "oauth")]
fn build_oauth_admin_router(
    proxy: &crate::oauth::OAuthProxyConfig,
    http: crate::oauth::OauthHttpClient,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let mut routes = axum::Router::new();

    if proxy.introspection_url.is_some() {
        let cfg = proxy.clone();
        let client = http.clone();
        let introspect = move |body: String| {
            let cfg = cfg.clone();
            let client = client.clone();
            async move { crate::oauth::handle_introspect(&client, &cfg, &body).await }
        };
        routes = routes.route("/introspect", axum::routing::post(introspect));
    }

    if proxy.revocation_url.is_some() {
        let cfg = proxy.clone();
        // Last use of the shared client: move it instead of cloning.
        let client = http;
        let revoke = move |body: String| {
            let cfg = cfg.clone();
            let client = client.clone();
            async move { crate::oauth::handle_revoke(&client, &cfg, &body).await }
        };
        routes = routes.route("/revoke", axum::routing::post(revoke));
    }

    let routes = routes.layer(axum::middleware::from_fn(
        oauth_token_cache_headers_middleware,
    ));

    if !proxy.require_auth_on_admin_endpoints {
        return Ok(routes);
    }

    let state = auth_state.cloned().ok_or_else(|| {
        McpxError::Startup("oauth proxy admin endpoints require auth state".into())
    })?;
    Ok(routes.layer(axum::middleware::from_fn(move |req, next| {
        auth_middleware(Arc::clone(&state), req, next)
    })))
}
2002
2003/// Build the host allow-list for rmcp's DNS rebinding protection.
2004///
2005/// Includes loopback hosts by default, then augments with host/authority
2006/// derived from `public_url` and the server bind address.
2007fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
2008    let mut hosts = vec![
2009        "localhost".to_owned(),
2010        "127.0.0.1".to_owned(),
2011        "::1".to_owned(),
2012    ];
2013
2014    if let Some(url) = public_url
2015        && let Ok(uri) = url.parse::<axum::http::Uri>()
2016        && let Some(authority) = uri.authority()
2017    {
2018        let host = authority.host().to_owned();
2019        if !hosts.iter().any(|h| h == &host) {
2020            hosts.push(host);
2021        }
2022
2023        let authority = authority.as_str().to_owned();
2024        if !hosts.iter().any(|h| h == &authority) {
2025            hosts.push(authority);
2026        }
2027    }
2028
2029    if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
2030        && let Some(authority) = uri.authority()
2031    {
2032        let host = authority.host().to_owned();
2033        if !hosts.iter().any(|h| h == &host) {
2034            hosts.push(host);
2035        }
2036
2037        let authority = authority.as_str().to_owned();
2038        if !hosts.iter().any(|h| h == &authority) {
2039            hosts.push(authority);
2040        }
2041    }
2042
2043    hosts
2044}
2045
2046// - TLS support -
2047
2048/// Implement axum's `Connected` trait for `TlsConnInfo` so that
2049/// `ConnectInfo<TlsConnInfo>` is available in middleware when serving
2050/// over our custom `TlsListener`.
2051///
2052/// The identity is read directly from the wrapping
2053/// [`AuthenticatedTlsStream`], which guarantees one-to-one correspondence
2054/// between the TLS connection and its mTLS identity. This eliminates the
2055/// previous shared-map approach which was vulnerable to ephemeral-port
2056/// reuse races (an unauthenticated reconnection from the same `(IP, port)`
2057/// pair could alias a stale entry).
2058impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
2059    for TlsConnInfo
2060{
2061    fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
2062        let addr = *target.remote_addr();
2063        let identity = target.io().identity().cloned();
2064        TlsConnInfo::new(addr, identity)
2065    }
2066}
2067
/// Default deadline applied to each individual TLS handshake on the
/// accept path. It stops idle or slow-loris peers from holding a
/// handshake worker task (and its semaphore permit) forever.
///
/// Configurable since 1.9.0 via
/// [`McpServerConfig::with_tls_handshake_timeout`].
const DEFAULT_TLS_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
2075
/// Default cap on the number of TLS handshakes in flight at once. Once
/// the cap is reached the acceptor task stops taking connections off
/// the kernel backlog (backpressure) rather than accepting them and
/// dropping them in user space.
///
/// Configurable since 1.9.0 via
/// [`McpServerConfig::with_max_concurrent_tls_handshakes`].
const DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES: usize = 256;
2084
/// Size of the bounded queue that carries completed handshakes from the
/// acceptor task to `axum::serve`'s `accept()` loop. When the queue is
/// full, handshake workers block on `send`, so a slow accept loop
/// back-pressures handshakes instead of letting finished connections
/// pile up without bound.
const TLS_ACCEPT_CHANNEL_CAPACITY: usize = 32;
2090
/// A TLS-wrapping listener that implements axum's `Listener` trait.
///
/// TCP accepts and TLS handshakes run on a dedicated background task: each
/// accepted connection's handshake is spawned onto its own worker task,
/// bounded by a configurable concurrent-handshake cap (default
/// [`DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES`]) and a per-handshake timeout
/// (default [`DEFAULT_TLS_HANDSHAKE_TIMEOUT`]). A slow or idle client
/// therefore cannot stall other connections behind a serialized inline
/// handshake.
///
/// When mTLS is configured, client certificates are verified against the
/// configured CA and the client identity is extracted at handshake time.
/// The extracted identity is bound to the connection itself via the
/// returned [`AuthenticatedTlsStream`], so it is impossible for an
/// unrelated connection to observe it.
///
/// Instances are created by `TlsListener::new`, which builds the rustls
/// server configuration and spawns the acceptor task.
struct TlsListener {
    /// Bound address, captured eagerly before the `TcpListener` moves into
    /// the acceptor task.
    local_addr: SocketAddr,
    /// Completed handshakes produced by the acceptor task's workers.
    /// Each item pairs the authenticated stream with the peer's socket
    /// address; the channel is bounded by [`TLS_ACCEPT_CHANNEL_CAPACITY`].
    rx: mpsc::Receiver<(AuthenticatedTlsStream, SocketAddr)>,
    /// Background task driving TCP accepts and concurrent TLS handshakes.
    /// Aborted on drop so the listener releases its port deterministically.
    // NOTE(review): the `Drop` impl that performs the abort is not part
    // of this section — confirm it exists alongside this type.
    acceptor_task: tokio::task::JoinHandle<()>,
}
2116
2117impl TlsListener {
2118    fn new(
2119        inner: TcpListener,
2120        cert_path: &Path,
2121        key_path: &Path,
2122        mtls_config: Option<&MtlsConfig>,
2123        crl_set: Option<Arc<CrlSet>>,
2124        handshake_timeout: Duration,
2125        max_concurrent_handshakes: usize,
2126    ) -> anyhow::Result<Self> {
2127        // Install the ring crypto provider (ok to call multiple times).
2128        rustls::crypto::ring::default_provider()
2129            .install_default()
2130            .ok();
2131
2132        let certs = load_certs(cert_path)?;
2133        let key = load_key(key_path)?;
2134
2135        let mtls_default_role;
2136
2137        let tls_config = if let Some(mtls) = mtls_config {
2138            mtls_default_role = mtls.default_role.clone();
2139            let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
2140            {
2141                let Some(crl_set) = crl_set else {
2142                    return Err(anyhow::anyhow!(
2143                        "mTLS CRL verifier requested but CRL state was not initialized"
2144                    ));
2145                };
2146                Arc::new(DynamicClientCertVerifier::new(crl_set))
2147            } else {
2148                let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
2149                if mtls.required {
2150                    rustls::server::WebPkiClientVerifier::builder(root_store)
2151                        .build()
2152                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
2153                } else {
2154                    rustls::server::WebPkiClientVerifier::builder(root_store)
2155                        .allow_unauthenticated()
2156                        .build()
2157                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
2158                }
2159            };
2160
2161            tracing::info!(
2162                ca = %mtls.ca_cert_path.display(),
2163                required = mtls.required,
2164                crl_enabled = mtls.crl_enabled,
2165                "mTLS client auth configured"
2166            );
2167
2168            rustls::ServerConfig::builder_with_protocol_versions(&[
2169                &rustls::version::TLS12,
2170                &rustls::version::TLS13,
2171            ])
2172            .with_client_cert_verifier(verifier)
2173            .with_single_cert(certs, key)?
2174        } else {
2175            mtls_default_role = "viewer".to_owned();
2176            rustls::ServerConfig::builder_with_protocol_versions(&[
2177                &rustls::version::TLS12,
2178                &rustls::version::TLS13,
2179            ])
2180            .with_no_client_auth()
2181            .with_single_cert(certs, key)?
2182        };
2183
2184        let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
2185        tracing::info!(
2186            "TLS enabled (cert: {}, key: {})",
2187            cert_path.display(),
2188            key_path.display()
2189        );
2190        let local_addr = inner.local_addr()?;
2191        let (tx, rx) = mpsc::channel(TLS_ACCEPT_CHANNEL_CAPACITY);
2192        let acceptor_task = tokio::spawn(run_tls_acceptor(
2193            inner,
2194            acceptor,
2195            mtls_default_role,
2196            tx,
2197            handshake_timeout,
2198            max_concurrent_handshakes,
2199        ));
2200        Ok(Self {
2201            local_addr,
2202            rx,
2203            acceptor_task,
2204        })
2205    }
2206
2207    /// Extract the mTLS client cert identity from a completed TLS handshake.
2208    /// Returns `None` if no client certificate was presented or if the
2209    /// certificate could not be parsed into an [`AuthIdentity`].
2210    fn extract_handshake_identity(
2211        tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
2212        default_role: &str,
2213        addr: SocketAddr,
2214    ) -> Option<AuthIdentity> {
2215        let (_, server_conn) = tls_stream.get_ref();
2216        let cert_der = server_conn.peer_certificates()?.first()?;
2217        let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
2218        tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
2219        Some(id)
2220    }
2221}
2222
2223/// Drive TCP accepts and concurrent TLS handshakes for [`TlsListener`].
2224///
2225/// Each accepted connection's handshake runs on its own worker task under
2226/// a permit from a `max_concurrent_handshakes`-sized semaphore and a
2227/// `handshake_timeout` deadline. Completed handshakes are pushed to `tx`;
2228/// failures and timeouts are logged at DEBUG and the connection dropped.
2229/// The loop exits when the owning [`TlsListener`] is dropped.
2230async fn run_tls_acceptor(
2231    listener: TcpListener,
2232    acceptor: tokio_rustls::TlsAcceptor,
2233    default_role: String,
2234    tx: mpsc::Sender<(AuthenticatedTlsStream, SocketAddr)>,
2235    handshake_timeout: Duration,
2236    max_concurrent_handshakes: usize,
2237) {
2238    let inflight = Arc::new(Semaphore::new(max_concurrent_handshakes));
2239    loop {
2240        // Acquire the permit BEFORE accepting: at saturation, pending
2241        // connections wait in the kernel backlog instead of being accepted
2242        // and then buffered or dropped in user space.
2243        let Ok(permit) = Arc::clone(&inflight).acquire_owned().await else {
2244            // The semaphore is never closed; defensive exit.
2245            return;
2246        };
2247        let (stream, addr) = match listener.accept().await {
2248            Ok(pair) => pair,
2249            Err(e) => {
2250                tracing::debug!("TCP accept error: {e}");
2251                continue;
2252            }
2253        };
2254        if tx.is_closed() {
2255            // The listener was dropped (shutdown): stop accepting.
2256            return;
2257        }
2258        let acceptor = acceptor.clone();
2259        let default_role = default_role.clone();
2260        let tx = tx.clone();
2261        tokio::spawn(async move {
2262            let _permit = permit;
2263            match tokio::time::timeout(handshake_timeout, acceptor.accept(stream)).await {
2264                Ok(Ok(tls_stream)) => {
2265                    let identity =
2266                        TlsListener::extract_handshake_identity(&tls_stream, &default_role, addr);
2267                    let wrapped = AuthenticatedTlsStream {
2268                        inner: tls_stream,
2269                        identity,
2270                    };
2271                    // The receiver only disappears during shutdown; discard
2272                    // the completed connection quietly rather than logging.
2273                    let _ = tx.send((wrapped, addr)).await;
2274                }
2275                Ok(Err(e)) => {
2276                    tracing::debug!("TLS handshake failed from {addr}: {e}");
2277                }
2278                Err(_elapsed) => {
2279                    tracing::debug!(
2280                        "TLS handshake timed out from {addr} after {handshake_timeout:?}"
2281                    );
2282                }
2283            }
2284        });
2285    }
2286}
2287
2288/// A TLS stream paired with the mTLS identity extracted at handshake time.
2289///
2290/// Wraps [`tokio_rustls::server::TlsStream`] so the verified client
2291/// identity travels with the connection itself. This replaces the previous
2292/// shared `MtlsIdentities` map, eliminating the
2293/// `(SocketAddr) -> AuthIdentity` aliasing risk caused by ephemeral-port
2294/// reuse and removing the need for an LRU eviction policy.
2295///
2296/// The wrapper is `Unpin` (its inner stream is `Unpin` because
2297/// [`tokio::net::TcpStream`] is `Unpin`), so `AsyncRead`/`AsyncWrite`
2298/// delegation uses safe pin projection via `Pin::new(&mut self.inner)`.
2299pub(crate) struct AuthenticatedTlsStream {
2300    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
2301    identity: Option<AuthIdentity>,
2302}
2303
2304impl AuthenticatedTlsStream {
2305    /// Returns the verified mTLS client identity, if any.
2306    #[must_use]
2307    pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
2308        self.identity.as_ref()
2309    }
2310}
2311
2312impl std::fmt::Debug for AuthenticatedTlsStream {
2313    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2314        f.debug_struct("AuthenticatedTlsStream")
2315            .field("identity", &self.identity.as_ref().map(|id| &id.name))
2316            .finish_non_exhaustive()
2317    }
2318}
2319
2320impl tokio::io::AsyncRead for AuthenticatedTlsStream {
2321    fn poll_read(
2322        mut self: Pin<&mut Self>,
2323        cx: &mut std::task::Context<'_>,
2324        buf: &mut tokio::io::ReadBuf<'_>,
2325    ) -> std::task::Poll<std::io::Result<()>> {
2326        Pin::new(&mut self.inner).poll_read(cx, buf)
2327    }
2328}
2329
2330impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
2331    fn poll_write(
2332        mut self: Pin<&mut Self>,
2333        cx: &mut std::task::Context<'_>,
2334        buf: &[u8],
2335    ) -> std::task::Poll<std::io::Result<usize>> {
2336        Pin::new(&mut self.inner).poll_write(cx, buf)
2337    }
2338
2339    fn poll_flush(
2340        mut self: Pin<&mut Self>,
2341        cx: &mut std::task::Context<'_>,
2342    ) -> std::task::Poll<std::io::Result<()>> {
2343        Pin::new(&mut self.inner).poll_flush(cx)
2344    }
2345
2346    fn poll_shutdown(
2347        mut self: Pin<&mut Self>,
2348        cx: &mut std::task::Context<'_>,
2349    ) -> std::task::Poll<std::io::Result<()>> {
2350        Pin::new(&mut self.inner).poll_shutdown(cx)
2351    }
2352
2353    fn poll_write_vectored(
2354        mut self: Pin<&mut Self>,
2355        cx: &mut std::task::Context<'_>,
2356        bufs: &[std::io::IoSlice<'_>],
2357    ) -> std::task::Poll<std::io::Result<usize>> {
2358        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
2359    }
2360
2361    fn is_write_vectored(&self) -> bool {
2362        self.inner.is_write_vectored()
2363    }
2364}
2365
2366impl axum::serve::Listener for TlsListener {
2367    type Io = AuthenticatedTlsStream;
2368    type Addr = SocketAddr;
2369
2370    /// Yield the next fully-handshaken TLS connection.
2371    ///
2372    /// Cancel-safe: this is a plain `mpsc::Receiver::recv`, so cancelling
2373    /// the future (axum selects it against graceful shutdown) never loses
2374    /// a connection.
2375    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
2376        if let Some(pair) = self.rx.recv().await {
2377            return pair;
2378        }
2379        // The channel only closes if the acceptor task terminated, which
2380        // means the TcpListener is gone and the OS already refuses new
2381        // connections. `Listener::accept` is infallible and panicking is
2382        // forbidden, so park forever: existing connections keep being
2383        // served and graceful shutdown still completes.
2384        tracing::error!("TLS acceptor task terminated; no further connections will be accepted");
2385        std::future::pending().await
2386    }
2387
2388    fn local_addr(&self) -> std::io::Result<Self::Addr> {
2389        Ok(self.local_addr)
2390    }
2391}
2392
2393impl Drop for TlsListener {
2394    fn drop(&mut self) {
2395        // Stop accepting immediately and release the bound port. In-flight
2396        // handshake workers notice the closed channel and exit quietly.
2397        self.acceptor_task.abort();
2398    }
2399}
2400
2401fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
2402    use rustls::pki_types::pem::PemObject;
2403    let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
2404        .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
2405        .collect::<Result<_, _>>()
2406        .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
2407    anyhow::ensure!(
2408        !certs.is_empty(),
2409        "no certificates found in {}",
2410        path.display()
2411    );
2412    Ok(certs)
2413}
2414
2415fn load_client_auth_roots(
2416    path: &Path,
2417) -> anyhow::Result<(
2418    Vec<rustls::pki_types::CertificateDer<'static>>,
2419    Arc<RootCertStore>,
2420)> {
2421    let ca_certs = load_certs(path)?;
2422    let mut root_store = RootCertStore::empty();
2423    for cert in &ca_certs {
2424        root_store
2425            .add(cert.clone())
2426            .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
2427    }
2428
2429    Ok((ca_certs, Arc::new(root_store)))
2430}
2431
2432fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
2433    use rustls::pki_types::pem::PemObject;
2434    rustls::pki_types::PrivateKeyDer::from_pem_file(path)
2435        .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
2436}
2437
2438#[allow(
2439    clippy::unused_async,
2440    reason = "axum route handler signature requires `async fn` even when the body is synchronous"
2441)]
2442async fn healthz() -> impl IntoResponse {
2443    axum::Json(serde_json::json!({
2444        "status": "ok",
2445    }))
2446}
2447
2448/// Build the `/version` JSON payload for a given server name and version.
2449///
2450/// Build metadata (`build_git_sha`, `build_timestamp`, `rust_version`) is
2451/// read at compile time from the `RMCP_SERVER_KIT_BUILD_SHA`,
2452/// `RMCP_SERVER_KIT_BUILD_TIME`, and `RMCP_SERVER_KIT_RUSTC_VERSION` env
2453/// vars. Unset values resolve to `"unknown"`.
2454fn version_payload(name: &str, version: &str) -> serde_json::Value {
2455    serde_json::json!({
2456        "name": name,
2457        "version": version,
2458        "build_git_sha": option_env!("RMCP_SERVER_KIT_BUILD_SHA").unwrap_or("unknown"),
2459        "build_timestamp": option_env!("RMCP_SERVER_KIT_BUILD_TIME").unwrap_or("unknown"),
2460        "rust_version": option_env!("RMCP_SERVER_KIT_RUSTC_VERSION").unwrap_or("unknown"),
2461        "mcpx_version": env!("CARGO_PKG_VERSION"),
2462    })
2463}
2464
2465/// Pre-serialize the `/version` payload to immutable bytes.
2466///
2467/// This is called once at router-build time so per-request handling can
2468/// reuse a cheap `Arc<[u8]>` clone instead of re-serializing a
2469/// [`serde_json::Value`] on every hit.
2470///
2471/// Serialization of a flat `serde_json::Value` of static-string fields
2472/// cannot fail in practice; the fallback to `b"{}"` exists only to
2473/// satisfy the crate-wide `unwrap_used` / `expect_used` lint policy.
2474fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
2475    let value = version_payload(name, version);
2476    serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
2477}
2478
2479async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2480    let status = check().await;
2481    let ready = status
2482        .get("ready")
2483        .and_then(serde_json::Value::as_bool)
2484        .unwrap_or(false);
2485    let code = if ready {
2486        axum::http::StatusCode::OK
2487    } else {
2488        axum::http::StatusCode::SERVICE_UNAVAILABLE
2489    };
2490    (code, axum::Json(status))
2491}
2492
2493/// Wait for SIGINT (ctrl-c) or SIGTERM (container stop).
2494///
2495/// On non-Unix platforms, only SIGINT is handled.
2496async fn shutdown_signal() {
2497    let ctrl_c = tokio::signal::ctrl_c();
2498
2499    #[cfg(unix)]
2500    {
2501        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
2502            Ok(mut term) => {
2503                tokio::select! {
2504                    _ = ctrl_c => {}
2505                    _ = term.recv() => {}
2506                }
2507            }
2508            Err(e) => {
2509                tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
2510                ctrl_c.await.ok();
2511            }
2512        }
2513    }
2514
2515    #[cfg(not(unix))]
2516    {
2517        ctrl_c.await.ok();
2518    }
2519}
2520
2521// -- Origin validation (MCP 2025-11-25 spec, section 2.0.1) --
2522
/// Record HTTP request metrics (method, path, status, duration).
///
/// After the inner service has produced a response, increments
/// `http_requests_total` labelled with method, path and status code, and
/// observes the elapsed wall-clock seconds in
/// `http_request_duration_seconds` labelled with method and path.
///
/// NOTE(review): the `path` label is the raw request URI path, so its
/// cardinality follows whatever paths clients request. Confirm that this
/// is acceptable for the metrics backend in use.
#[cfg(feature = "metrics")]
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    // Capture the labels before `req` is moved into the inner service.
    let method = req.method().to_string();
    let path = req.uri().path().to_owned();
    let started = std::time::Instant::now();

    let response = next.run(req).await;

    let status = response.status().as_u16().to_string();
    let elapsed_secs = started.elapsed().as_secs_f64();

    let counter = metrics
        .http_requests_total
        .with_label_values(&[&method, &path, &status]);
    counter.inc();

    let histogram = metrics
        .http_request_duration_seconds
        .with_label_values(&[&method, &path]);
    histogram.observe(elapsed_secs);

    response
}
2552
2553/// OWASP security header hardening applied to every response.
2554///
2555/// Sets: `X-Content-Type-Options`, `X-Frame-Options`, `Cache-Control`,
2556/// `Referrer-Policy`, `Cross-Origin-Opener-Policy`, `Cross-Origin-Resource-Policy`,
2557/// `Cross-Origin-Embedder-Policy`, `Permissions-Policy`,
2558/// `X-Permitted-Cross-Domain-Policies`, `Content-Security-Policy`,
2559/// `X-DNS-Prefetch-Control`, and (when TLS is active) `Strict-Transport-Security`.
2560///
2561/// Each header's value can be customised via [`SecurityHeadersConfig`]
2562/// on [`McpServerConfig`]. See that type for the three-state semantic
2563/// (`None` = default, `Some("")` = omit, `Some(v)` = override).
2564async fn security_headers_middleware(
2565    is_tls: bool,
2566    cfg: Arc<SecurityHeadersConfig>,
2567    req: Request<Body>,
2568    next: Next,
2569) -> axum::response::Response {
2570    use axum::http::{HeaderName, header};
2571
2572    let mut resp = next.run(req).await;
2573    let headers = resp.headers_mut();
2574
2575    // Strip server identity headers to reduce information leakage.
2576    headers.remove(header::SERVER);
2577    headers.remove(HeaderName::from_static("x-powered-by"));
2578
2579    apply_security_header(
2580        headers,
2581        header::X_CONTENT_TYPE_OPTIONS,
2582        cfg.x_content_type_options.as_deref(),
2583        "nosniff",
2584    );
2585    apply_security_header(
2586        headers,
2587        header::X_FRAME_OPTIONS,
2588        cfg.x_frame_options.as_deref(),
2589        "deny",
2590    );
2591    apply_security_header(
2592        headers,
2593        header::CACHE_CONTROL,
2594        cfg.cache_control.as_deref(),
2595        "no-store, max-age=0",
2596    );
2597    apply_security_header(
2598        headers,
2599        header::REFERRER_POLICY,
2600        cfg.referrer_policy.as_deref(),
2601        "no-referrer",
2602    );
2603    apply_security_header(
2604        headers,
2605        HeaderName::from_static("cross-origin-opener-policy"),
2606        cfg.cross_origin_opener_policy.as_deref(),
2607        "same-origin",
2608    );
2609    apply_security_header(
2610        headers,
2611        HeaderName::from_static("cross-origin-resource-policy"),
2612        cfg.cross_origin_resource_policy.as_deref(),
2613        "same-origin",
2614    );
2615    apply_security_header(
2616        headers,
2617        HeaderName::from_static("cross-origin-embedder-policy"),
2618        cfg.cross_origin_embedder_policy.as_deref(),
2619        "require-corp",
2620    );
2621    apply_security_header(
2622        headers,
2623        HeaderName::from_static("permissions-policy"),
2624        cfg.permissions_policy.as_deref(),
2625        "accelerometer=(), camera=(), geolocation=(), microphone=()",
2626    );
2627    apply_security_header(
2628        headers,
2629        HeaderName::from_static("x-permitted-cross-domain-policies"),
2630        cfg.x_permitted_cross_domain_policies.as_deref(),
2631        "none",
2632    );
2633    apply_security_header(
2634        headers,
2635        HeaderName::from_static("content-security-policy"),
2636        cfg.content_security_policy.as_deref(),
2637        "default-src 'none'; frame-ancestors 'none'",
2638    );
2639    apply_security_header(
2640        headers,
2641        HeaderName::from_static("x-dns-prefetch-control"),
2642        cfg.x_dns_prefetch_control.as_deref(),
2643        "off",
2644    );
2645
2646    if is_tls {
2647        apply_security_header(
2648            headers,
2649            header::STRICT_TRANSPORT_SECURITY,
2650            cfg.strict_transport_security.as_deref(),
2651            "max-age=63072000; includeSubDomains",
2652        );
2653    }
2654
2655    resp
2656}
2657
2658/// Set a single security header on the response, honouring the
2659/// three-state override semantic (None = default, Some("") = omit,
2660/// Some(value) = override).
2661///
2662/// Defence-in-depth: if an override value somehow reaches this point
2663/// despite [`validate_security_headers`] having approved it (e.g. a
2664/// runtime mutation on a non-`Validated` field), we log at error level
2665/// and fall back to the static default rather than panicking. The
2666/// `Validated<McpServerConfig>` type makes that path unreachable in
2667/// well-typed code paths.
2668fn apply_security_header(
2669    headers: &mut axum::http::HeaderMap,
2670    name: axum::http::HeaderName,
2671    override_value: Option<&str>,
2672    default: &'static str,
2673) {
2674    use axum::http::HeaderValue;
2675
2676    match override_value {
2677        None => {
2678            headers.insert(name, HeaderValue::from_static(default));
2679        }
2680        Some("") => {
2681            // Operator explicitly opted out of this header.
2682        }
2683        Some(v) => match HeaderValue::from_str(v) {
2684            Ok(hv) => {
2685                headers.insert(name, hv);
2686            }
2687            Err(err) => {
2688                tracing::error!(
2689                    header = %name,
2690                    error = %err,
2691                    "invalid security header override reached middleware; using default"
2692                );
2693                headers.insert(name, HeaderValue::from_static(default));
2694            }
2695        },
2696    }
2697}
2698
2699/// Validate every non-empty entry in a [`SecurityHeadersConfig`].
2700///
2701/// - `None` and `Some("")` are accepted unconditionally (use-default and
2702///   omit, respectively).
2703/// - `Some(v)` is rejected if `axum::http::HeaderValue::from_str(v)` fails.
2704/// - `strict_transport_security` additionally rejects any value
2705///   containing `preload` (case-insensitive). Operators who genuinely
2706///   want to commit to the HSTS preload list must do so via a future
2707///   explicit `with_hsts_preload(true)` builder, not by smuggling
2708///   `preload` through this knob.
2709fn validate_security_headers(cfg: &SecurityHeadersConfig) -> Result<(), McpxError> {
2710    use axum::http::HeaderValue;
2711
2712    let fields: &[(&str, Option<&str>)] = &[
2713        (
2714            "x_content_type_options",
2715            cfg.x_content_type_options.as_deref(),
2716        ),
2717        ("x_frame_options", cfg.x_frame_options.as_deref()),
2718        ("cache_control", cfg.cache_control.as_deref()),
2719        ("referrer_policy", cfg.referrer_policy.as_deref()),
2720        (
2721            "cross_origin_opener_policy",
2722            cfg.cross_origin_opener_policy.as_deref(),
2723        ),
2724        (
2725            "cross_origin_resource_policy",
2726            cfg.cross_origin_resource_policy.as_deref(),
2727        ),
2728        (
2729            "cross_origin_embedder_policy",
2730            cfg.cross_origin_embedder_policy.as_deref(),
2731        ),
2732        ("permissions_policy", cfg.permissions_policy.as_deref()),
2733        (
2734            "x_permitted_cross_domain_policies",
2735            cfg.x_permitted_cross_domain_policies.as_deref(),
2736        ),
2737        (
2738            "content_security_policy",
2739            cfg.content_security_policy.as_deref(),
2740        ),
2741        (
2742            "x_dns_prefetch_control",
2743            cfg.x_dns_prefetch_control.as_deref(),
2744        ),
2745        (
2746            "strict_transport_security",
2747            cfg.strict_transport_security.as_deref(),
2748        ),
2749    ];
2750
2751    for (field, value) in fields {
2752        let Some(v) = value else { continue };
2753        if v.is_empty() {
2754            continue;
2755        }
2756        if let Err(err) = HeaderValue::from_str(v) {
2757            return Err(McpxError::Config(format!(
2758                "invalid security_headers.{field}: {err}"
2759            )));
2760        }
2761    }
2762
2763    if let Some(v) = cfg.strict_transport_security.as_deref()
2764        && !v.is_empty()
2765        && v.to_ascii_lowercase().contains("preload")
2766    {
2767        return Err(McpxError::Config(format!(
2768            "invalid security_headers.strict_transport_security: {v:?} contains the `preload` directive; \
2769             HSTS preload must be opted into explicitly via a dedicated builder, not via this knob"
2770        )));
2771    }
2772
2773    Ok(())
2774}
2775
/// Add the cache-related headers expected on OAuth token-issuing
/// responses (cited by the original authors as RFC 6749 §5.1 and
/// RFC 6750 §5.4).
///
/// `Cache-Control: no-store, max-age=0` is already applied globally by
/// [`security_headers_middleware`]; this middleware adds:
///
/// - `Pragma: no-cache`, replacing any existing `Pragma` value.
/// - `Vary: Authorization`, appended rather than inserted, so any `Vary`
///   value already present (e.g. `Accept-Encoding` from a compression
///   layer, or `Origin` from a CORS layer) is preserved.
///
/// Applied only to the OAuth proxy token-class endpoints (`/token`,
/// `/register`, `/introspect`, `/revoke`).
#[cfg(feature = "oauth")]
async fn oauth_token_cache_headers_middleware(
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    use axum::http::{HeaderValue, header};

    let mut response = next.run(req).await;
    response
        .headers_mut()
        .insert(header::PRAGMA, HeaderValue::from_static("no-cache"));
    response
        .headers_mut()
        .append(header::VARY, HeaderValue::from_static("Authorization"));
    response
}
2803
2804/// Normalize peer-address request extensions across listener branches.
2805///
2806/// The make-service installs `ConnectInfo<SocketAddr>` on the plain
2807/// listener but `ConnectInfo<TlsConnInfo>` on the TLS listener (the
2808/// latter additionally carries the connection-bound mTLS identity and
2809/// stays `pub(crate)` — see the anti-aliasing rationale on
2810/// [`TlsConnInfo`]). Application routes — in particular those merged via
2811/// [`McpServerConfig::with_extra_router`], which bypass the auth
2812/// middleware and its private fallback — could therefore not read the
2813/// peer address under TLS.
2814///
2815/// This middleware makes both branches look identical to every route and
2816/// inner middleware:
2817///
2818/// 1. mirrors the TLS peer address into `ConnectInfo<SocketAddr>` when
2819///    (and only when) it is absent, so stock axum-ecosystem extractors
2820///    work unmodified, and
2821/// 2. inserts the framework-owned [`PeerAddr`] extension on both
2822///    branches.
2823///
2824/// Precedence mirrors the auth middleware: an existing
2825/// `ConnectInfo<SocketAddr>` always wins and is never overwritten. The
2826/// peer address is deliberately not logged here.
2827async fn normalize_peer_addr_middleware(
2828    mut req: Request<Body>,
2829    next: Next,
2830) -> axum::response::Response {
2831    let direct = req
2832        .extensions()
2833        .get::<ConnectInfo<SocketAddr>>()
2834        .map(|ci| ci.0);
2835    let from_tls = req
2836        .extensions()
2837        .get::<ConnectInfo<TlsConnInfo>>()
2838        .map(|ci| ci.0.addr);
2839    if let Some(addr) = direct.or(from_tls) {
2840        if direct.is_none() {
2841            req.extensions_mut().insert(ConnectInfo(addr));
2842        }
2843        req.extensions_mut().insert(PeerAddr::new(addr));
2844    }
2845    next.run(req).await
2846}
2847
2848/// Per the MCP spec: if the Origin header is present and its value is not in
2849/// the allowed list, respond with 403 Forbidden. Requests without an Origin
2850/// header are allowed through (e.g. non-browser clients like curl, SDKs).
2851async fn origin_check_middleware(
2852    allowed: Arc<[String]>,
2853    log_request_headers: bool,
2854    req: Request<Body>,
2855    next: Next,
2856) -> axum::response::Response {
2857    let method = req.method().clone();
2858    let path = req.uri().path().to_owned();
2859
2860    log_incoming_request(&method, &path, req.headers(), log_request_headers);
2861
2862    if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
2863        let origin_str = origin.to_str().unwrap_or("");
2864        if !allowed.iter().any(|a| a == origin_str) {
2865            tracing::warn!(
2866                origin = origin_str,
2867                %method,
2868                %path,
2869                allowed = ?&*allowed,
2870                "rejected request: Origin not allowed"
2871            );
2872            return (
2873                axum::http::StatusCode::FORBIDDEN,
2874                "Forbidden: Origin not allowed",
2875            )
2876                .into_response();
2877        }
2878    }
2879    next.run(req).await
2880}
2881
2882/// Emit a DEBUG log for an incoming request, optionally including the full
2883/// (redacted) header set.
2884fn log_incoming_request(
2885    method: &axum::http::Method,
2886    path: &str,
2887    headers: &axum::http::HeaderMap,
2888    log_request_headers: bool,
2889) {
2890    if log_request_headers {
2891        tracing::debug!(
2892            %method,
2893            %path,
2894            headers = %format_request_headers_for_log(headers),
2895            "incoming request"
2896        );
2897    } else {
2898        tracing::debug!(%method, %path, "incoming request");
2899    }
2900}
2901
2902fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
2903    headers
2904        .iter()
2905        .map(|(k, v)| {
2906            let name = k.as_str();
2907            if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
2908                format!("{name}: [REDACTED]")
2909            } else {
2910                format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
2911            }
2912        })
2913        .collect::<Vec<_>>()
2914        .join(", ")
2915}
2916
2917// -- stdio transport --
2918
2919/// Serve an MCP server over stdin/stdout (stdio transport).
2920///
2921/// # Security warnings
2922///
2923/// - **No authentication**: the parent process has full, unrestricted access.
2924/// - **No RBAC**: all tools are available regardless of policy.
2925/// - **No TLS**: messages travel over OS pipes in plaintext.
2926/// - **Single client**: only the parent process can connect.
2927/// - **No Origin validation**: not applicable to stdio.
2928///
2929/// Use this only when the MCP client spawns the server as a trusted subprocess
2930/// (e.g. Claude Desktop, VS Code Copilot). For network-accessible deployments,
2931/// use `serve()` (Streamable HTTP) instead.
2932///
2933/// # Errors
2934///
2935/// Returns [`McpxError::Startup`] if the handler fails to initialize or the
2936/// transport disconnects unexpectedly.
2937// NOTE: reported complexity 32/25 is driven entirely by `tracing::*!`
2938// macro expansion in this 18-line function (info/warn/info + two matches).
2939// There is nothing meaningful to extract; the allow stays.
2940#[allow(
2941    clippy::cognitive_complexity,
2942    reason = "complexity is purely tracing macro expansion (info/warn + match arms); 18 lines of straight-line code, nothing meaningful to extract"
2943)]
2944pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
2945where
2946    H: ServerHandler + 'static,
2947{
2948    use rmcp::ServiceExt as _;
2949
2950    tracing::info!("stdio transport: serving on stdin/stdout");
2951    tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
2952
2953    let transport = rmcp::transport::io::stdio();
2954
2955    let service = handler
2956        .serve(transport)
2957        .await
2958        .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
2959
2960    if let Err(e) = service.waiting().await {
2961        tracing::warn!(error = %e, "stdio session ended with error");
2962    }
2963    tracing::info!("stdio session ended");
2964    Ok(())
2965}
2966
2967#[cfg(test)]
2968mod tests {
2969    #![allow(
2970        clippy::unwrap_used,
2971        clippy::expect_used,
2972        clippy::panic,
2973        clippy::indexing_slicing,
2974        clippy::unwrap_in_result,
2975        clippy::print_stdout,
2976        clippy::print_stderr,
2977        deprecated,
2978        reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
2979    )]
2980    use std::{sync::Arc, time::Duration};
2981
2982    use axum::{
2983        body::Body,
2984        http::{Request, StatusCode, header},
2985        response::IntoResponse,
2986    };
2987    use http_body_util::BodyExt;
2988    use tower::ServiceExt as _;
2989
2990    use super::*;
2991
2992    // -- McpServerConfig --
2993
2994    #[test]
2995    fn server_config_new_defaults() {
2996        let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
2997        assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
2998        assert_eq!(cfg.name, "test-server");
2999        assert_eq!(cfg.version, "1.0.0");
3000        assert!(cfg.tls_cert_path.is_none());
3001        assert!(cfg.tls_key_path.is_none());
3002        assert!(cfg.auth.is_none());
3003        assert!(cfg.rbac.is_none());
3004        assert!(cfg.allowed_origins.is_empty());
3005        assert!(cfg.tool_rate_limit.is_none());
3006        assert!(cfg.readiness_check.is_none());
3007        assert_eq!(cfg.max_request_body, 1024 * 1024);
3008        assert_eq!(cfg.request_timeout, Duration::from_mins(2));
3009        assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
3010        assert!(!cfg.log_request_headers);
3011        assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(10));
3012        assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
3013    }
3014
3015    #[test]
3016    fn tls_handshake_builders_set_fields() {
3017        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3018            .with_tls_handshake_timeout(Duration::from_secs(3))
3019            .with_max_concurrent_tls_handshakes(64);
3020        assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(3));
3021        assert_eq!(cfg.max_concurrent_tls_handshakes, 64);
3022    }
3023
3024    #[test]
3025    fn validate_rejects_zero_tls_handshake_timeout() {
3026        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3027            .with_tls_handshake_timeout(Duration::ZERO);
3028        let err = cfg.validate().expect_err("zero handshake timeout");
3029        assert!(err.to_string().contains("tls_handshake_timeout"));
3030    }
3031
3032    #[test]
3033    fn validate_rejects_zero_max_concurrent_tls_handshakes() {
3034        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3035            .with_max_concurrent_tls_handshakes(0);
3036        let err = cfg.validate().expect_err("zero handshake concurrency");
3037        assert!(err.to_string().contains("max_concurrent_tls_handshakes"));
3038    }
3039
3040    #[test]
3041    fn validate_consumes_and_proves() {
3042        // Valid config -> Validated wrapper, original is consumed.
3043        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
3044        let validated = cfg.validate().expect("valid config");
3045        // as_inner() gives read-only access to inner fields.
3046        assert_eq!(validated.as_inner().name, "test-server");
3047        // into_inner recovers the raw value.
3048        let raw = validated.into_inner();
3049        assert_eq!(raw.name, "test-server");
3050
3051        // Invalid config (zero max_request_body) -> Err.
3052        let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
3053        bad.max_request_body = 0;
3054        assert!(bad.validate().is_err(), "zero body cap must fail validate");
3055    }
3056
3057    #[test]
3058    fn validate_rejects_zero_max_concurrent_requests() {
3059        let cfg =
3060            McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_max_concurrent_requests(0);
3061        let err = cfg.validate().expect_err("zero concurrency cap must fail");
3062        assert!(
3063            format!("{err}").contains("max_concurrent_requests"),
3064            "error should mention max_concurrent_requests, got: {err}"
3065        );
3066    }
3067
3068    #[test]
3069    fn validate_rejects_zero_max_tracked_keys() {
3070        // Defaults mirror auth::default_max_attempts / default_idle_eviction
3071        // (module-private in auth.rs); spelled out here for review clarity.
3072        let rl = crate::auth::RateLimitConfig {
3073            max_attempts_per_minute: 30,
3074            pre_auth_max_per_minute: None,
3075            max_tracked_keys: 0,
3076            idle_eviction: Duration::from_secs(15 * 60),
3077        };
3078        let auth_cfg = AuthConfig {
3079            enabled: true,
3080            api_keys: Vec::new(),
3081            mtls: None,
3082            rate_limit: Some(rl),
3083            #[cfg(feature = "oauth")]
3084            oauth: None,
3085        };
3086        let cfg = McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_auth(auth_cfg);
3087        let err = cfg.validate().expect_err("zero max_tracked_keys must fail");
3088        assert!(
3089            format!("{err}").contains("max_tracked_keys"),
3090            "error should mention max_tracked_keys, got: {err}"
3091        );
3092    }
3093
3094    #[test]
3095    fn derive_allowed_hosts_includes_public_host() {
3096        let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
3097        assert!(
3098            hosts.iter().any(|h| h == "mcp.example.com"),
3099            "public_url host must be allowed"
3100        );
3101    }
3102
3103    #[test]
3104    fn derive_allowed_hosts_includes_bind_authority() {
3105        let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
3106        assert!(
3107            hosts.iter().any(|h| h == "127.0.0.1"),
3108            "bind host must be allowed"
3109        );
3110        assert!(
3111            hosts.iter().any(|h| h == "127.0.0.1:8080"),
3112            "bind authority must be allowed"
3113        );
3114    }
3115
3116    // -- healthz --
3117
3118    #[tokio::test]
3119    async fn healthz_returns_ok_json() {
3120        let resp = healthz().await.into_response();
3121        assert_eq!(resp.status(), StatusCode::OK);
3122        let body = resp.into_body().collect().await.unwrap().to_bytes();
3123        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
3124        assert_eq!(json["status"], "ok");
3125        assert!(
3126            json.get("name").is_none(),
3127            "healthz must not expose server name"
3128        );
3129        assert!(
3130            json.get("version").is_none(),
3131            "healthz must not expose version"
3132        );
3133    }
3134
3135    // -- readyz --
3136
3137    #[tokio::test]
3138    async fn readyz_returns_ok_when_ready() {
3139        let check: ReadinessCheck =
3140            Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
3141        let resp = readyz(check).await.into_response();
3142        assert_eq!(resp.status(), StatusCode::OK);
3143        let body = resp.into_body().collect().await.unwrap().to_bytes();
3144        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
3145        assert_eq!(json["ready"], true);
3146        assert!(
3147            json.get("name").is_none(),
3148            "readyz must not expose server name"
3149        );
3150        assert!(
3151            json.get("version").is_none(),
3152            "readyz must not expose version"
3153        );
3154        assert_eq!(json["db"], "connected");
3155    }
3156
3157    #[tokio::test]
3158    async fn readyz_returns_503_when_not_ready() {
3159        let check: ReadinessCheck =
3160            Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
3161        let resp = readyz(check).await.into_response();
3162        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
3163    }
3164
3165    #[tokio::test]
3166    async fn readyz_returns_503_when_ready_missing() {
3167        let check: ReadinessCheck =
3168            Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
3169        let resp = readyz(check).await.into_response();
3170        // Missing "ready" field defaults to false -> 503
3171        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
3172    }
3173
3174    // -- normalize_peer_addr_middleware / PeerAddr --
3175
3176    /// Build a test router that reports the request's peer-address
3177    /// extensions as `"<ConnectInfo>|<PeerAddr>"` (empty when absent).
3178    fn peer_probe_router() -> axum::Router {
3179        async fn probe(req: Request<Body>) -> String {
3180            let ci = req
3181                .extensions()
3182                .get::<ConnectInfo<SocketAddr>>()
3183                .map(|c| c.0.to_string())
3184                .unwrap_or_default();
3185            let pa = req
3186                .extensions()
3187                .get::<PeerAddr>()
3188                .map(|p| p.addr.to_string())
3189                .unwrap_or_default();
3190            format!("{ci}|{pa}")
3191        }
3192        axum::Router::new()
3193            .route("/probe", axum::routing::get(probe))
3194            .layer(axum::middleware::from_fn(normalize_peer_addr_middleware))
3195    }
3196
3197    async fn body_string(resp: axum::response::Response) -> String {
3198        let bytes = resp.into_body().collect().await.unwrap().to_bytes();
3199        String::from_utf8(bytes.to_vec()).unwrap()
3200    }
3201
3202    #[tokio::test]
3203    async fn normalize_preserves_existing_connect_info_and_mirrors_peer_addr() {
3204        // Precedence proof: when both extensions exist with DIFFERENT
3205        // addresses, ConnectInfo<SocketAddr> wins and is never overwritten.
3206        let plain: SocketAddr = "10.0.0.1:1111".parse().unwrap();
3207        let tls: SocketAddr = "10.0.0.2:2222".parse().unwrap();
3208        let req = Request::builder()
3209            .uri("/probe")
3210            .extension(ConnectInfo(plain))
3211            .extension(ConnectInfo(TlsConnInfo::new(tls, None)))
3212            .body(Body::empty())
3213            .unwrap();
3214        let resp = peer_probe_router().oneshot(req).await.unwrap();
3215        assert_eq!(resp.status(), StatusCode::OK);
3216        assert_eq!(body_string(resp).await, format!("{plain}|{plain}"));
3217    }
3218
3219    #[tokio::test]
3220    async fn normalize_inserts_connect_info_and_peer_addr_from_tls() {
3221        let tls: SocketAddr = "192.168.1.7:50443".parse().unwrap();
3222        let req = Request::builder()
3223            .uri("/probe")
3224            .extension(ConnectInfo(TlsConnInfo::new(tls, None)))
3225            .body(Body::empty())
3226            .unwrap();
3227        let resp = peer_probe_router().oneshot(req).await.unwrap();
3228        assert_eq!(resp.status(), StatusCode::OK);
3229        assert_eq!(body_string(resp).await, format!("{tls}|{tls}"));
3230    }
3231
3232    #[tokio::test]
3233    async fn normalize_no_op_without_any_connect_info() {
3234        let req = Request::builder()
3235            .uri("/probe")
3236            .body(Body::empty())
3237            .unwrap();
3238        let resp = peer_probe_router().oneshot(req).await.unwrap();
3239        assert_eq!(resp.status(), StatusCode::OK);
3240        assert_eq!(body_string(resp).await, "|");
3241    }
3242
3243    #[tokio::test]
3244    async fn peer_addr_extractor_rejects_when_absent() {
3245        async fn h(peer: PeerAddr) -> String {
3246            peer.addr.to_string()
3247        }
3248        let app = axum::Router::new().route("/p", axum::routing::get(h));
3249        let req = Request::builder().uri("/p").body(Body::empty()).unwrap();
3250        let resp = app.oneshot(req).await.unwrap();
3251        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
3252    }
3253
3254    #[tokio::test]
3255    async fn peer_addr_extractor_returns_value_when_present() {
3256        async fn h(peer: PeerAddr) -> String {
3257            peer.addr.to_string()
3258        }
3259        let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap();
3260        let app = axum::Router::new().route("/p", axum::routing::get(h));
3261        let req = Request::builder()
3262            .uri("/p")
3263            .extension(PeerAddr::new(addr))
3264            .body(Body::empty())
3265            .unwrap();
3266        let resp = app.oneshot(req).await.unwrap();
3267        assert_eq!(resp.status(), StatusCode::OK);
3268        assert_eq!(body_string(resp).await, addr.to_string());
3269    }
3270
3271    #[tokio::test]
3272    async fn peer_addr_via_extension_extractor() {
3273        async fn h(axum::Extension(peer): axum::Extension<PeerAddr>) -> String {
3274            peer.addr.to_string()
3275        }
3276        let addr: SocketAddr = "127.0.0.1:4242".parse().unwrap();
3277        let app = axum::Router::new().route("/p", axum::routing::get(h));
3278        let req = Request::builder()
3279            .uri("/p")
3280            .extension(PeerAddr::new(addr))
3281            .body(Body::empty())
3282            .unwrap();
3283        let resp = app.oneshot(req).await.unwrap();
3284        assert_eq!(resp.status(), StatusCode::OK);
3285        assert_eq!(body_string(resp).await, addr.to_string());
3286    }
3287
3288    // -- origin_check_middleware --
3289
3290    /// Build a test router with origin check middleware and a simple handler.
3291    fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
3292        let allowed: Arc<[String]> = Arc::from(origins);
3293        axum::Router::new()
3294            .route("/test", axum::routing::get(|| async { "ok" }))
3295            .layer(axum::middleware::from_fn(move |req, next| {
3296                let a = Arc::clone(&allowed);
3297                origin_check_middleware(a, log_request_headers, req, next)
3298            }))
3299    }
3300
3301    #[tokio::test]
3302    async fn origin_allowed_passes() {
3303        let app = origin_router(vec!["http://localhost:3000".into()], false);
3304        let req = Request::builder()
3305            .uri("/test")
3306            .header(header::ORIGIN, "http://localhost:3000")
3307            .body(Body::empty())
3308            .unwrap();
3309        let resp = app.oneshot(req).await.unwrap();
3310        assert_eq!(resp.status(), StatusCode::OK);
3311    }
3312
3313    #[tokio::test]
3314    async fn origin_rejected_returns_403() {
3315        let app = origin_router(vec!["http://localhost:3000".into()], false);
3316        let req = Request::builder()
3317            .uri("/test")
3318            .header(header::ORIGIN, "http://evil.com")
3319            .body(Body::empty())
3320            .unwrap();
3321        let resp = app.oneshot(req).await.unwrap();
3322        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
3323    }
3324
3325    #[tokio::test]
3326    async fn no_origin_header_passes() {
3327        let app = origin_router(vec!["http://localhost:3000".into()], false);
3328        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3329        let resp = app.oneshot(req).await.unwrap();
3330        assert_eq!(resp.status(), StatusCode::OK);
3331    }
3332
3333    #[tokio::test]
3334    async fn empty_allowlist_rejects_any_origin() {
3335        let app = origin_router(vec![], false);
3336        let req = Request::builder()
3337            .uri("/test")
3338            .header(header::ORIGIN, "http://anything.com")
3339            .body(Body::empty())
3340            .unwrap();
3341        let resp = app.oneshot(req).await.unwrap();
3342        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
3343    }
3344
3345    #[tokio::test]
3346    async fn empty_allowlist_passes_without_origin() {
3347        let app = origin_router(vec![], false);
3348        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3349        let resp = app.oneshot(req).await.unwrap();
3350        assert_eq!(resp.status(), StatusCode::OK);
3351    }
3352
3353    #[test]
3354    fn format_request_headers_redacts_sensitive_values() {
3355        let mut headers = axum::http::HeaderMap::new();
3356        headers.insert("authorization", "Bearer secret-token".parse().unwrap());
3357        headers.insert("cookie", "sid=abc".parse().unwrap());
3358        headers.insert("x-request-id", "req-123".parse().unwrap());
3359
3360        let out = format_request_headers_for_log(&headers);
3361        assert!(out.contains("authorization: [REDACTED]"));
3362        assert!(out.contains("cookie: [REDACTED]"));
3363        assert!(out.contains("x-request-id: req-123"));
3364        assert!(!out.contains("secret-token"));
3365    }
3366
3367    // -- security_headers_middleware --
3368
3369    fn security_router(is_tls: bool) -> axum::Router {
3370        security_router_with(is_tls, SecurityHeadersConfig::default())
3371    }
3372
3373    fn security_router_with(is_tls: bool, cfg: SecurityHeadersConfig) -> axum::Router {
3374        let cfg = Arc::new(cfg);
3375        axum::Router::new()
3376            .route("/test", axum::routing::get(|| async { "ok" }))
3377            .layer(axum::middleware::from_fn(move |req, next| {
3378                let c = Arc::clone(&cfg);
3379                security_headers_middleware(is_tls, c, req, next)
3380            }))
3381    }
3382
3383    #[tokio::test]
3384    async fn security_headers_set_on_response() {
3385        let app = security_router(false);
3386        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3387        let resp = app.oneshot(req).await.unwrap();
3388        assert_eq!(resp.status(), StatusCode::OK);
3389
3390        let h = resp.headers();
3391        assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
3392        assert_eq!(h.get("x-frame-options").unwrap(), "deny");
3393        assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
3394        assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
3395        assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
3396        assert_eq!(
3397            h.get("cross-origin-resource-policy").unwrap(),
3398            "same-origin"
3399        );
3400        assert_eq!(
3401            h.get("cross-origin-embedder-policy").unwrap(),
3402            "require-corp"
3403        );
3404        assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
3405        assert!(
3406            h.get("permissions-policy")
3407                .unwrap()
3408                .to_str()
3409                .unwrap()
3410                .contains("camera=()"),
3411            "permissions-policy must restrict browser features"
3412        );
3413        assert_eq!(
3414            h.get("content-security-policy").unwrap(),
3415            "default-src 'none'; frame-ancestors 'none'"
3416        );
3417        assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
3418        // No HSTS when TLS is off.
3419        assert!(h.get("strict-transport-security").is_none());
3420    }
3421
3422    #[tokio::test]
3423    async fn hsts_set_when_tls_enabled() {
3424        let app = security_router(true);
3425        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3426        let resp = app.oneshot(req).await.unwrap();
3427
3428        let hsts = resp.headers().get("strict-transport-security").unwrap();
3429        assert!(
3430            hsts.to_str().unwrap().contains("max-age=63072000"),
3431            "HSTS must set 2-year max-age"
3432        );
3433    }
3434
3435    // -- SecurityHeadersConfig validation + override semantics --
3436
3437    /// Build a minimal config with a custom SecurityHeadersConfig and
3438    /// drive it through `check()`. Returns the result so individual
3439    /// tests can assert on success or specific error messages.
3440    fn check_with_security_headers(headers: SecurityHeadersConfig) -> Result<(), McpxError> {
3441        let cfg =
3442            McpServerConfig::new("127.0.0.1:8080", "test", "0.0.0").with_security_headers(headers);
3443        cfg.check()
3444    }
3445
3446    #[test]
3447    fn security_headers_config_default_validates() {
3448        check_with_security_headers(SecurityHeadersConfig::default())
3449            .expect("default SecurityHeadersConfig must validate");
3450    }
3451
3452    #[test]
3453    fn security_headers_config_validate_accepts_empty_string() {
3454        // All twelve fields explicitly set to "" -> omit-everything mode.
3455        let h = SecurityHeadersConfig {
3456            x_content_type_options: Some(String::new()),
3457            x_frame_options: Some(String::new()),
3458            cache_control: Some(String::new()),
3459            referrer_policy: Some(String::new()),
3460            cross_origin_opener_policy: Some(String::new()),
3461            cross_origin_resource_policy: Some(String::new()),
3462            cross_origin_embedder_policy: Some(String::new()),
3463            permissions_policy: Some(String::new()),
3464            x_permitted_cross_domain_policies: Some(String::new()),
3465            content_security_policy: Some(String::new()),
3466            x_dns_prefetch_control: Some(String::new()),
3467            strict_transport_security: Some(String::new()),
3468        };
3469        check_with_security_headers(h).expect("Some(\"\") on every field must validate (omit-all)");
3470    }
3471
3472    #[test]
3473    fn security_headers_config_validate_rejects_bad_value() {
3474        // 0x07 (BEL) is not a valid HTTP header value char.
3475        let h = SecurityHeadersConfig {
3476            referrer_policy: Some("\u{0007}".into()),
3477            ..SecurityHeadersConfig::default()
3478        };
3479        let err = check_with_security_headers(h)
3480            .expect_err("control char in referrer_policy must reject");
3481        let msg = err.to_string();
3482        assert!(
3483            msg.contains("referrer_policy"),
3484            "error must name the offending field, got: {msg}"
3485        );
3486    }
3487
3488    #[test]
3489    fn security_headers_config_validate_rejects_hsts_preload() {
3490        let h = SecurityHeadersConfig {
3491            strict_transport_security: Some("max-age=63072000; includeSubDomains; preload".into()),
3492            ..SecurityHeadersConfig::default()
3493        };
3494        let err = check_with_security_headers(h).expect_err("HSTS with preload must reject");
3495        let msg = err.to_string();
3496        assert!(
3497            msg.contains("strict_transport_security"),
3498            "error must name the field, got: {msg}"
3499        );
3500        assert!(
3501            msg.to_lowercase().contains("preload"),
3502            "error must mention `preload`, got: {msg}"
3503        );
3504    }
3505
3506    #[test]
3507    fn security_headers_config_validate_rejects_hsts_preload_uppercase() {
3508        // Case-insensitive match.
3509        let h = SecurityHeadersConfig {
3510            strict_transport_security: Some("max-age=600; PRELOAD".into()),
3511            ..SecurityHeadersConfig::default()
3512        };
3513        check_with_security_headers(h).expect_err("HSTS preload check must be case-insensitive");
3514    }
3515
3516    #[tokio::test]
3517    async fn security_headers_override_honored() {
3518        // Override X-Frame-Options to SAMEORIGIN.
3519        let h = SecurityHeadersConfig {
3520            x_frame_options: Some("SAMEORIGIN".into()),
3521            ..SecurityHeadersConfig::default()
3522        };
3523        let app = security_router_with(false, h);
3524        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3525        let resp = app.oneshot(req).await.unwrap();
3526        assert_eq!(resp.status(), StatusCode::OK);
3527
3528        let xfo = resp.headers().get("x-frame-options").unwrap();
3529        assert_eq!(xfo, "SAMEORIGIN");
3530    }
3531
3532    #[tokio::test]
3533    async fn security_headers_empty_string_omits() {
3534        // Empty string on referrer-policy -> header absent.
3535        let h = SecurityHeadersConfig {
3536            referrer_policy: Some(String::new()),
3537            ..SecurityHeadersConfig::default()
3538        };
3539        let app = security_router_with(false, h);
3540        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3541        let resp = app.oneshot(req).await.unwrap();
3542        assert_eq!(resp.status(), StatusCode::OK);
3543
3544        assert!(
3545            resp.headers().get("referrer-policy").is_none(),
3546            "Some(\"\") must omit the header"
3547        );
3548        // Other defaults should still be present.
3549        assert_eq!(
3550            resp.headers().get("x-content-type-options").unwrap(),
3551            "nosniff"
3552        );
3553    }
3554
3555    #[tokio::test]
3556    async fn security_headers_hsts_only_when_tls() {
3557        // HSTS override is irrelevant when TLS is off.
3558        let h = SecurityHeadersConfig {
3559            strict_transport_security: Some("max-age=600".into()),
3560            ..SecurityHeadersConfig::default()
3561        };
3562        let app = security_router_with(false, h);
3563        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3564        let resp = app.oneshot(req).await.unwrap();
3565        assert!(
3566            resp.headers().get("strict-transport-security").is_none(),
3567            "HSTS must remain absent on plaintext deployments even with override"
3568        );
3569    }
3570
3571    // -- oauth_token_cache_headers_middleware --
3572
3573    #[cfg(feature = "oauth")]
3574    #[tokio::test]
3575    async fn oauth_token_cache_headers_set_pragma_and_vary() {
3576        let app = axum::Router::new()
3577            .route("/token", axum::routing::post(|| async { "{}" }))
3578            .layer(axum::middleware::from_fn(
3579                oauth_token_cache_headers_middleware,
3580            ));
3581        let req = Request::builder()
3582            .method("POST")
3583            .uri("/token")
3584            .body(Body::from("{}"))
3585            .unwrap();
3586        let resp = app.oneshot(req).await.unwrap();
3587        assert_eq!(resp.status(), StatusCode::OK);
3588
3589        let h = resp.headers();
3590        assert_eq!(
3591            h.get("pragma").unwrap(),
3592            "no-cache",
3593            "RFC 6749 §5.1: token responses must set Pragma: no-cache"
3594        );
3595        let vary_values: Vec<String> = h
3596            .get_all("vary")
3597            .iter()
3598            .filter_map(|v| v.to_str().ok().map(str::to_owned))
3599            .collect();
3600        assert!(
3601            vary_values
3602                .iter()
3603                .any(|v| v.eq_ignore_ascii_case("Authorization")),
3604            "RFC 6750 §5.4: Vary must include Authorization, got {vary_values:?}"
3605        );
3606    }
3607
3608    #[cfg(feature = "oauth")]
3609    #[tokio::test]
3610    async fn oauth_token_cache_headers_preserve_existing_vary() {
3611        // Simulates a handler/layer that already set `Vary: Accept-Encoding`
3612        // (e.g. compression). Our middleware must APPEND, not REPLACE.
3613        let app = axum::Router::new()
3614            .route(
3615                "/token",
3616                axum::routing::post(|| async {
3617                    axum::response::Response::builder()
3618                        .header("vary", "Accept-Encoding")
3619                        .body(axum::body::Body::from("{}"))
3620                        .unwrap()
3621                }),
3622            )
3623            .layer(axum::middleware::from_fn(
3624                oauth_token_cache_headers_middleware,
3625            ));
3626        let req = Request::builder()
3627            .method("POST")
3628            .uri("/token")
3629            .body(Body::empty())
3630            .unwrap();
3631        let resp = app.oneshot(req).await.unwrap();
3632
3633        let vary: Vec<String> = resp
3634            .headers()
3635            .get_all("vary")
3636            .iter()
3637            .filter_map(|v| v.to_str().ok().map(str::to_owned))
3638            .collect();
3639        assert!(
3640            vary.iter().any(|v| v.contains("Accept-Encoding")),
3641            "must preserve pre-existing Vary value, got {vary:?}"
3642        );
3643        assert!(
3644            vary.iter().any(|v| v.contains("Authorization")),
3645            "must append Authorization to Vary, got {vary:?}"
3646        );
3647    }
3648
3649    // -- version endpoint --
3650
3651    #[test]
3652    fn version_payload_contains_expected_fields() {
3653        let v = version_payload("my-server", "1.2.3");
3654        assert_eq!(v["name"], "my-server");
3655        assert_eq!(v["version"], "1.2.3");
3656        assert!(v["build_git_sha"].is_string());
3657        assert!(v["build_timestamp"].is_string());
3658        assert!(v["rust_version"].is_string());
3659        assert!(v["mcpx_version"].is_string());
3660    }
3661
3662    // -- concurrency limit layer --
3663
3664    #[tokio::test]
3665    async fn concurrency_limit_layer_composes_and_serves() {
3666        // We only assert the layer stack compiles and a single request
3667        // below the cap still succeeds. True back-pressure behaviour
3668        // requires a live HTTP server and is covered by integration tests.
3669        let app = axum::Router::new()
3670            .route("/ok", axum::routing::get(|| async { "ok" }))
3671            .layer(
3672                tower::ServiceBuilder::new()
3673                    .layer(axum::error_handling::HandleErrorLayer::new(
3674                        |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
3675                    ))
3676                    .layer(tower::load_shed::LoadShedLayer::new())
3677                    .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
3678            );
3679        let resp = app
3680            .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
3681            .await
3682            .unwrap();
3683        assert_eq!(resp.status(), StatusCode::OK);
3684    }
3685
3686    // -- compression layer --
3687
3688    #[tokio::test]
3689    async fn compression_layer_gzip_encodes_response() {
3690        use tower_http::compression::Predicate as _;
3691
3692        let big_body = "a".repeat(4096);
3693        let app = axum::Router::new()
3694            .route(
3695                "/big",
3696                axum::routing::get(move || {
3697                    let body = big_body.clone();
3698                    async move { body }
3699                }),
3700            )
3701            .layer(
3702                tower_http::compression::CompressionLayer::new()
3703                    .gzip(true)
3704                    .br(true)
3705                    .compress_when(
3706                        tower_http::compression::DefaultPredicate::new()
3707                            .and(tower_http::compression::predicate::SizeAbove::new(1024)),
3708                    ),
3709            );
3710
3711        let req = Request::builder()
3712            .uri("/big")
3713            .header(header::ACCEPT_ENCODING, "gzip")
3714            .body(Body::empty())
3715            .unwrap();
3716        let resp = app.oneshot(req).await.unwrap();
3717        assert_eq!(resp.status(), StatusCode::OK);
3718        assert_eq!(
3719            resp.headers().get(header::CONTENT_ENCODING).unwrap(),
3720            "gzip"
3721        );
3722    }
3723
3724    // -- TlsListener handshake timeout --
3725
3726    #[tokio::test]
3727    async fn tls_handshake_timeout_reaps_idle_connections() {
3728        use tokio::io::AsyncReadExt as _;
3729
3730        let _ = rustls::crypto::ring::default_provider().install_default();
3731
3732        // Self-signed cert material on disk (TlsListener::new takes paths).
3733        let key = rcgen::KeyPair::generate().expect("generate key");
3734        let cert = rcgen::CertificateParams::new(vec!["localhost".to_owned()])
3735            .expect("cert params")
3736            .self_signed(&key)
3737            .expect("self-signed cert");
3738        let dir = std::env::temp_dir().join(format!(
3739            "rmcp-server-kit-hs-timeout-{}",
3740            std::time::SystemTime::now()
3741                .duration_since(std::time::UNIX_EPOCH)
3742                .expect("clock after epoch")
3743                .as_nanos()
3744        ));
3745        tokio::fs::create_dir_all(&dir).await.expect("temp dir");
3746        let cert_path = dir.join("server.crt");
3747        let key_path = dir.join("server.key");
3748        tokio::fs::write(&cert_path, cert.pem())
3749            .await
3750            .expect("write cert");
3751        tokio::fs::write(&key_path, key.serialize_pem())
3752            .await
3753            .expect("write key");
3754
3755        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
3756        let tls = TlsListener::new(
3757            listener,
3758            &cert_path,
3759            &key_path,
3760            None,
3761            None,
3762            Duration::from_millis(200),
3763            8, // custom concurrency cap: proves the plumbing end-to-end
3764        )
3765        .expect("tls listener");
3766        let addr = axum::serve::Listener::local_addr(&tls).expect("local addr");
3767
3768        // Connect and send NOTHING: the handshake worker must time out
3769        // after 200ms and drop the stream, which the client observes as
3770        // EOF or a reset well within the 2s deadline.
3771        let mut idle = tokio::net::TcpStream::connect(addr).await.expect("connect");
3772        let mut buf = [0_u8; 16];
3773        let read = tokio::time::timeout(Duration::from_secs(2), idle.read(&mut buf))
3774            .await
3775            .expect("server must reap the idle handshake within its timeout");
3776        match read {
3777            Ok(0) | Err(_) => {} // EOF or reset: connection was dropped.
3778            Ok(n) => panic!("unexpected {n} bytes from server during reaped handshake"),
3779        }
3780
3781        drop(tls);
3782    }
3783}