Skip to main content

rmcp_server_kit/
transport.rs

1use std::{
2    future::Future,
3    net::{IpAddr, SocketAddr},
4    path::{Path, PathBuf},
5    pin::Pin,
6    sync::Arc,
7    time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{
12    body::Body,
13    extract::{ConnectInfo, Request},
14    middleware::Next,
15    response::IntoResponse,
16};
17use rmcp::{
18    ServerHandler,
19    transport::streamable_http_server::{
20        StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
21    },
22};
23use rustls::RootCertStore;
24use tokio::{
25    net::TcpListener,
26    sync::{Semaphore, mpsc},
27};
28use tokio_util::sync::CancellationToken;
29
30use crate::{
31    auth::{
32        AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
33        build_rate_limiter, extract_mtls_identity,
34    },
35    bounded_limiter::BoundedKeyedLimiter,
36    error::McpxError,
37    mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
38    rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
39};
40
41/// Map an internal `anyhow::Error` chain into a public [`McpxError::Startup`]
42/// at the public API boundary, flattening the chain via the alternate
43/// formatter so callers see the full causal path.
44#[allow(
45    clippy::needless_pass_by_value,
46    reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
47)]
48fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
49    McpxError::Startup(format!("{e:#}"))
50}
51
52/// Map a `std::io::Error` produced during server startup into a public
53/// [`McpxError::Startup`]. We deliberately do not use the [`McpxError::Io`]
54/// `From` impl here because startup-phase IO errors (bind, listener) are
55/// semantically distinct from request-time IO errors and should surface
56/// the originating operation in the message.
57#[allow(
58    clippy::needless_pass_by_value,
59    reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
60)]
61fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
62    McpxError::Startup(format!("{op}: {e}"))
63}
64
/// Async readiness check callback for the `/readyz` endpoint.
///
/// The callback takes no arguments and returns a boxed, `Send` future
/// that resolves to a [`serde_json::Value`]. It is held behind an
/// [`Arc`] and bounded by `Send + Sync`, so a single probe can be
/// shared across threads.
///
/// The resolved value is a JSON object with at least a `"ready"` boolean.
/// When `ready` is false, the endpoint returns HTTP 503.
pub type ReadinessCheck =
    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
71
/// Remote socket address of the HTTP/TLS connection a request arrived on.
///
/// [`serve`] attaches this as a request extension to every request it
/// serves, on the plain listener and the TLS listener alike. Any axum
/// handler can extract it, including handlers on routes mounted via
/// [`McpServerConfig::with_extra_router`] — those routes bypass auth/RBAC
/// and therefore often need the peer address for their own protection
/// (e.g. per-IP rate limiting).
///
/// On the TLS listener the same address is additionally mirrored into
/// [`axum::extract::ConnectInfo`]`<SocketAddr>`, so third-party
/// middleware that expects the stock axum extension (e.g. per-IP
/// rate-limit key extractors) works unmodified under TLS.
///
/// # Semantics
///
/// - **Direct peer only.** The value is the socket's remote address.
///   Behind an L4/L7 proxy or load balancer that is the proxy's address;
///   the framework performs **no** `X-Forwarded-For` / `Forwarded`
///   interpretation.
/// - **Available on HTTP and TLS** transports alike ([`serve`]).
/// - **Absent under [`serve_stdio`]** — a stdio session has no network
///   peer (stdio bypasses the HTTP stack entirely).
/// - The separate Prometheus metrics listener (feature `metrics`) is a
///   different router and does not carry this extension.
///
/// # Privacy
///
/// `PeerAddr` exposes raw peer network metadata. The framework
/// deliberately never logs it on its own; whether peer addresses are
/// logged or persisted is application policy.
///
/// # Example
///
/// ```no_run
/// use axum::{Router, routing::get};
/// use rmcp_server_kit::transport::{McpServerConfig, PeerAddr};
///
/// async fn whoami(peer: PeerAddr) -> String {
///     peer.addr.ip().to_string()
/// }
///
/// let _config = McpServerConfig::new("127.0.0.1:8443", "my-server", "1.0.0")
///     .with_extra_router(Router::new().route("/whoami", get(whoami)));
/// ```
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PeerAddr {
    /// Remote address of the socket this connection was accepted on.
    pub addr: SocketAddr,
}

impl PeerAddr {
    /// Wrap `addr` in a [`PeerAddr`].
    ///
    /// Framework-internal: downstream code obtains a `PeerAddr` from the
    /// request extensions and never builds one itself.
    #[must_use]
    pub(crate) const fn new(addr: SocketAddr) -> Self {
        Self { addr }
    }
}
131
132/// Extract [`PeerAddr`] from request extensions.
133///
134/// # Rejection
135///
136/// Responds `500 Internal Server Error` when the extension is missing.
137/// A missing `PeerAddr` means the handler is not running under [`serve`]
138/// (e.g. the router was mounted on a hand-rolled listener) — a wiring
139/// bug, not a client error.
140impl<S: Send + Sync> axum::extract::FromRequestParts<S> for PeerAddr {
141    type Rejection = (axum::http::StatusCode, &'static str);
142
143    async fn from_request_parts(
144        parts: &mut axum::http::request::Parts,
145        _state: &S,
146    ) -> Result<Self, Self::Rejection> {
147        parts.extensions.get::<Self>().copied().ok_or((
148            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
149            "peer address unavailable: not running under rmcp-server-kit serve()",
150        ))
151    }
152}
153
/// Overrides, one per header, for the OWASP security headers that the
/// global response middleware emits.
///
/// Every field has the same three-state meaning:
///
/// | Value         | Behaviour                                                |
/// |---------------|----------------------------------------------------------|
/// | `None`        | Use the built-in default (current behaviour).            |
/// | `Some("")`    | **Omit** the header entirely from responses.             |
/// | `Some(value)` | Emit `header: value`. Validated at config-load time.     |
///
/// Non-empty values are all checked with
/// [`axum::http::HeaderValue::from_str`] inside
/// [`McpServerConfig::validate`], so an invalid value fails fast, before
/// the server starts accepting traffic.
///
/// `Strict-Transport-Security` carries one extra rule: a value containing
/// the substring `preload` (case-insensitive) is rejected. Operators who
/// want to commit to the HSTS preload list must do so via a future
/// explicit builder method, not by smuggling it through this knob.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct SecurityHeadersConfig {
    /// `X-Content-Type-Options` override. Default: `nosniff`.
    pub x_content_type_options: Option<String>,
    /// `X-Frame-Options` override. Default: `deny`.
    pub x_frame_options: Option<String>,
    /// `Cache-Control` override. Default: `no-store, max-age=0`.
    pub cache_control: Option<String>,
    /// `Referrer-Policy` override. Default: `no-referrer`.
    pub referrer_policy: Option<String>,
    /// `Cross-Origin-Opener-Policy` override. Default: `same-origin`.
    pub cross_origin_opener_policy: Option<String>,
    /// `Cross-Origin-Resource-Policy` override. Default: `same-origin`.
    pub cross_origin_resource_policy: Option<String>,
    /// `Cross-Origin-Embedder-Policy` override. Default: `require-corp`.
    pub cross_origin_embedder_policy: Option<String>,
    /// `Permissions-Policy` override. Default:
    /// `accelerometer=(), camera=(), geolocation=(), microphone=()`.
    pub permissions_policy: Option<String>,
    /// `X-Permitted-Cross-Domain-Policies` override. Default: `none`.
    pub x_permitted_cross_domain_policies: Option<String>,
    /// `Content-Security-Policy` override. Default:
    /// `default-src 'none'; frame-ancestors 'none'`.
    pub content_security_policy: Option<String>,
    /// `X-DNS-Prefetch-Control` override. Default: `off`.
    pub x_dns_prefetch_control: Option<String>,
    /// `Strict-Transport-Security` override. Default (TLS only):
    /// `max-age=63072000; includeSubDomains`. The header is only emitted
    /// when TLS is active; on plaintext deployments the override is
    /// ignored. The validator rejects the substring `preload` (any case).
    pub strict_transport_security: Option<String>,
}
207
208/// Configuration for the MCP server.
209#[allow(
210    missing_debug_implementations,
211    reason = "contains callback/trait objects that don't impl Debug"
212)]
213#[allow(
214    clippy::struct_excessive_bools,
215    reason = "server configuration naturally has many boolean feature flags"
216)]
217#[non_exhaustive]
218pub struct McpServerConfig {
219    /// Socket address the MCP HTTP server binds to.
220    #[deprecated(
221        since = "0.13.0",
222        note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in a future major release"
223    )]
224    pub bind_addr: String,
225    /// Server name advertised via MCP `initialize`.
226    #[deprecated(
227        since = "0.13.0",
228        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
229    )]
230    pub name: String,
231    /// Server version advertised via MCP `initialize`.
232    #[deprecated(
233        since = "0.13.0",
234        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
235    )]
236    pub version: String,
237    /// Path to the TLS certificate (PEM). Required for TLS/mTLS.
238    #[deprecated(
239        since = "0.13.0",
240        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
241    )]
242    pub tls_cert_path: Option<PathBuf>,
243    /// Path to the TLS private key (PEM). Required for TLS/mTLS.
244    #[deprecated(
245        since = "0.13.0",
246        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
247    )]
248    pub tls_key_path: Option<PathBuf>,
249    /// Optional authentication config. When `Some` and `enabled`, auth
250    /// is enforced on `/mcp`. `/healthz` is always open.
251    #[deprecated(
252        since = "0.13.0",
253        note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in a future major release"
254    )]
255    pub auth: Option<AuthConfig>,
256    /// Optional RBAC policy. When present and enabled, tool calls are
257    /// checked against the policy after authentication.
258    #[deprecated(
259        since = "0.13.0",
260        note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in a future major release"
261    )]
262    pub rbac: Option<Arc<RbacPolicy>>,
263    /// Allowed Origin values for DNS rebinding protection (MCP spec MUST).
264    /// When empty and `public_url` is set, the origin is auto-derived from
265    /// the public URL. When both are empty, only requests with no Origin
266    /// header are accepted.
267    /// Example entries: `"http://localhost:3000"`, `"https://myapp.example.com"`.
268    #[deprecated(
269        since = "0.13.0",
270        note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in a future major release"
271    )]
272    pub allowed_origins: Vec<String>,
273    /// Maximum tool invocations per source IP per minute.
274    /// When set, enforced on every `tools/call` request.
275    #[deprecated(
276        since = "0.13.0",
277        note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in a future major release"
278    )]
279    pub tool_rate_limit: Option<u32>,
280    /// Burst capacity for the tool rate limiter: maximum `tools/call`
281    /// requests admitted back-to-back before the sustained
282    /// [`tool_rate_limit`](Self::tool_rate_limit) rate applies. `None`
283    /// (default) keeps governor's default of burst = rate. Requires
284    /// `tool_rate_limit` to be set; must be greater than zero.
285    #[deprecated(
286        since = "1.12.0",
287        note = "use McpServerConfig::with_tool_rate_limit_burst(); direct field access will become pub(crate) in a future major release"
288    )]
289    pub tool_rate_limit_burst: Option<u32>,
290    /// Maximum requests per source IP per minute for routes merged via
291    /// [`with_extra_router`](Self::with_extra_router). Opt-in: `None`
292    /// (the default) installs no limiter. Startup-only (not
293    /// hot-reloadable via [`ReloadHandle`]).
294    ///
295    /// Keyed by the **direct socket peer** ([`PeerAddr`] semantics — no
296    /// `X-Forwarded-For` interpretation): behind a reverse proxy all
297    /// clients share the proxy's bucket, and IPv6 single-host address
298    /// rotation can evade per-IP keying. Treat this as an abuse speed
299    /// bump for unauthenticated application endpoints, not tenant
300    /// isolation. On limit: HTTP 429 with a plain-text body, matching
301    /// the tool/auth limiters.
302    #[deprecated(
303        since = "1.11.0",
304        note = "use McpServerConfig::with_extra_route_rate_limit(); direct field access will become pub(crate) in a future major release"
305    )]
306    pub extra_route_rate_limit: Option<u32>,
307    /// Burst capacity for the extra-route limiter: maximum requests
308    /// admitted back-to-back before the sustained
309    /// [`extra_route_rate_limit`](Self::extra_route_rate_limit) rate
310    /// applies. `None` (default) keeps governor's default of
311    /// burst = rate. Requires `extra_route_rate_limit` to be set; must
312    /// be greater than zero.
313    #[deprecated(
314        since = "1.12.0",
315        note = "use McpServerConfig::with_extra_route_rate_limit_burst(); direct field access will become pub(crate) in a future major release"
316    )]
317    pub extra_route_rate_limit_burst: Option<u32>,
318    /// Optional readiness probe for `/readyz`.
319    /// When `None`, `/readyz` mirrors `/healthz` (always OK).
320    #[deprecated(
321        since = "0.13.0",
322        note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in a future major release"
323    )]
324    pub readiness_check: Option<ReadinessCheck>,
325    /// Maximum request body size in bytes. Default: 1 MiB.
326    /// Protects against oversized payloads causing OOM.
327    #[deprecated(
328        since = "0.13.0",
329        note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in a future major release"
330    )]
331    pub max_request_body: usize,
332    /// Request processing timeout. Default: 120s.
333    /// Requests exceeding this duration receive 408 Request Timeout.
334    #[deprecated(
335        since = "0.13.0",
336        note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in a future major release"
337    )]
338    pub request_timeout: Duration,
339    /// Graceful shutdown timeout. Default: 30s.
340    /// After the shutdown signal, in-flight requests have this long to finish.
341    #[deprecated(
342        since = "0.13.0",
343        note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in a future major release"
344    )]
345    pub shutdown_timeout: Duration,
346    /// Idle timeout for MCP sessions. Sessions with no activity for this
347    /// duration are closed automatically. Default: 20 minutes.
348    #[deprecated(
349        since = "0.13.0",
350        note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in a future major release"
351    )]
352    pub session_idle_timeout: Duration,
353    /// Interval for SSE keep-alive pings. Prevents proxies and load
354    /// balancers from killing idle connections. Default: 15 seconds.
355    #[deprecated(
356        since = "0.13.0",
357        note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in a future major release"
358    )]
359    pub sse_keep_alive: Duration,
360    /// Callback invoked once the server is built, delivering a
361    /// [`ReloadHandle`] for hot-reloading auth keys and RBAC policy
362    /// at runtime (e.g. on SIGHUP). Only useful when auth/RBAC is enabled.
363    #[deprecated(
364        since = "0.13.0",
365        note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in a future major release"
366    )]
367    pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
368    /// Additional application-specific routes merged into the top-level
369    /// router.  These routes **bypass** the MCP auth and RBAC middleware,
370    /// so the application is responsible for its own auth on them.
371    /// Handlers can extract [`PeerAddr`] (or
372    /// [`axum::extract::ConnectInfo`]`<SocketAddr>` for third-party
373    /// middleware compatibility) regardless of whether TLS is enabled.
374    #[deprecated(
375        since = "0.13.0",
376        note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in a future major release"
377    )]
378    pub extra_router: Option<axum::Router>,
379    /// Externally reachable base URL (e.g. `https://mcp.example.com`).
380    /// When set, OAuth metadata endpoints advertise this URL instead of
381    /// the listen address. Required when binding `0.0.0.0` behind a
382    /// reverse proxy or inside a container.
383    #[deprecated(
384        since = "0.13.0",
385        note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in a future major release"
386    )]
387    pub public_url: Option<String>,
388    /// Log inbound HTTP request headers at DEBUG level.
389    /// Sensitive values remain redacted.
390    #[deprecated(
391        since = "0.13.0",
392        note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in a future major release"
393    )]
394    pub log_request_headers: bool,
395    /// Enable gzip/br response compression on MCP responses.
396    /// Defaults to `false` to preserve existing behaviour.
397    #[deprecated(
398        since = "0.13.0",
399        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
400    )]
401    pub compression_enabled: bool,
402    /// Minimum response body size (in bytes) before compression kicks in.
403    /// Only used when `compression_enabled` is true. Default: 1024.
404    #[deprecated(
405        since = "0.13.0",
406        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
407    )]
408    pub compression_min_size: u16,
409    /// Global cap on in-flight HTTP requests across the whole server.
410    /// When `Some`, requests over the cap receive 503 Service Unavailable
411    /// via `tower::load_shed`. Default: `None` (unlimited).
412    #[deprecated(
413        since = "0.13.0",
414        note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in a future major release"
415    )]
416    pub max_concurrent_requests: Option<usize>,
417    /// Enable `/admin/*` diagnostic endpoints. Requires `auth` to be
418    /// configured and `enabled`. Default: `false`.
419    #[deprecated(
420        since = "0.13.0",
421        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
422    )]
423    pub admin_enabled: bool,
424    /// RBAC role required to access admin endpoints. Default: `"admin"`.
425    #[deprecated(
426        since = "0.13.0",
427        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
428    )]
429    pub admin_role: String,
430    /// Enable Prometheus metrics endpoint on a separate listener.
431    /// Requires the `metrics` crate feature.
432    #[cfg(feature = "metrics")]
433    #[deprecated(
434        since = "0.13.0",
435        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
436    )]
437    pub metrics_enabled: bool,
438    /// Bind address for the Prometheus metrics listener. Default: `127.0.0.1:9090`.
439    #[cfg(feature = "metrics")]
440    #[deprecated(
441        since = "0.13.0",
442        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
443    )]
444    pub metrics_bind: String,
445    /// Per-header overrides for the OWASP security headers emitted by
446    /// the global response middleware. See [`SecurityHeadersConfig`]
447    /// for the three-state semantic and validation rules.
448    #[deprecated(
449        since = "1.5.0",
450        note = "use McpServerConfig::with_security_headers(); direct field access will become pub(crate) in a future major release"
451    )]
452    pub security_headers: SecurityHeadersConfig,
453    /// Per-handshake deadline on the TLS accept path. Idle or slow-loris
454    /// connections are dropped once it elapses. Default: 10s.
455    ///
456    /// Startup-only: bound at listener construction, NOT hot-reloadable
457    /// via [`ReloadHandle`]. Ignored unless TLS is configured.
458    #[deprecated(
459        since = "1.9.0",
460        note = "use McpServerConfig::with_tls_handshake_timeout(); direct field access will become pub(crate) in a future major release"
461    )]
462    pub tls_handshake_timeout: Duration,
463    /// Cap on concurrently in-flight TLS handshakes. At saturation the
464    /// acceptor stops pulling new connections from the kernel backlog
465    /// (backpressure) instead of accepting and dropping. Default: 256.
466    ///
467    /// Startup-only: bound at listener construction, NOT hot-reloadable
468    /// via [`ReloadHandle`]. Ignored unless TLS is configured.
469    #[deprecated(
470        since = "1.9.0",
471        note = "use McpServerConfig::with_max_concurrent_tls_handshakes(); direct field access will become pub(crate) in a future major release"
472    )]
473    pub max_concurrent_tls_handshakes: usize,
474}
475
/// Proof wrapper around a value that has passed its validation contract.
///
/// `Validated<McpServerConfig>` can only be obtained by calling
/// [`McpServerConfig::validate`]; [`serve`] and [`serve_with_listener`]
/// enforce that contract at the type level. Because the inner field is
/// private, downstream code cannot hand-construct the wrapper to bypass
/// validation.
///
/// Borrow the value read-only with [`Validated::as_inner`]. To mutate it,
/// take the raw value back with [`Validated::into_inner`] and
/// re-validate.
///
/// # Example
///
/// ```no_run
/// use rmcp_server_kit::transport::{McpServerConfig, Validated, serve};
/// use rmcp::handler::server::ServerHandler;
/// use rmcp::model::{ServerCapabilities, ServerInfo};
///
/// #[derive(Clone)]
/// struct H;
/// impl ServerHandler for H {
///     fn get_info(&self) -> ServerInfo {
///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
///     }
/// }
///
/// # async fn example() -> rmcp_server_kit::Result<()> {
/// let config: Validated<McpServerConfig> =
///     McpServerConfig::new("127.0.0.1:8080", "my-server", "0.1.0").validate()?;
/// serve(config, || H).await
/// # }
/// ```
///
/// Forgetting `.validate()?` is a compile error:
///
/// ```compile_fail
/// use rmcp_server_kit::transport::{McpServerConfig, serve};
/// use rmcp::handler::server::ServerHandler;
/// use rmcp::model::{ServerCapabilities, ServerInfo};
///
/// #[derive(Clone)]
/// struct H;
/// impl ServerHandler for H {
///     fn get_info(&self) -> ServerInfo {
///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
///     }
/// }
///
/// # async fn example() -> rmcp_server_kit::Result<()> {
/// let config = McpServerConfig::new("127.0.0.1:8080", "my-server", "0.1.0");
/// // Missing `.validate()?` -> mismatched types: expected
/// // `Validated<McpServerConfig>`, found `McpServerConfig`.
/// serve(config, || H).await
/// # }
/// ```
#[allow(
    missing_debug_implementations,
    reason = "wraps T which may not implement Debug; manual impl below avoids leaking inner contents into logs"
)]
pub struct Validated<T>(T);

/// Opaque `Debug`: prints only the wrapper name and never the wrapped
/// value, so it works for any `T` and keeps the contents out of logs.
impl<T> std::fmt::Debug for Validated<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut redacted = f.debug_struct("Validated");
        redacted.finish_non_exhaustive()
    }
}

impl<T> Validated<T> {
    /// Borrow the wrapped value read-only.
    #[must_use]
    pub fn as_inner(&self) -> &T {
        let Self(inner) = self;
        inner
    }

    /// Take the raw value back, giving up the validation proof.
    ///
    /// Validate it again before passing it to [`serve`] or
    /// [`serve_with_listener`].
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}
561
562#[allow(
563    deprecated,
564    reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
565)]
566impl McpServerConfig {
567    /// Create a new server configuration with the given bind address,
568    /// server name, and version. All other fields use safe defaults.
569    ///
570    /// Use the chainable `with_*` / `enable_*` builder methods to
571    /// customize. Call [`McpServerConfig::validate`] to obtain a
572    /// [`Validated<McpServerConfig>`] proof token, which is required by
573    /// [`serve`] and [`serve_with_listener`].
574    #[must_use]
575    pub fn new(
576        bind_addr: impl Into<String>,
577        name: impl Into<String>,
578        version: impl Into<String>,
579    ) -> Self {
580        Self {
581            bind_addr: bind_addr.into(),
582            name: name.into(),
583            version: version.into(),
584            tls_cert_path: None,
585            tls_key_path: None,
586            auth: None,
587            rbac: None,
588            allowed_origins: Vec::new(),
589            tool_rate_limit: None,
590            readiness_check: None,
591            max_request_body: 1024 * 1024,
592            request_timeout: Duration::from_mins(2),
593            shutdown_timeout: Duration::from_secs(30),
594            session_idle_timeout: Duration::from_mins(20),
595            sse_keep_alive: Duration::from_secs(15),
596            on_reload_ready: None,
597            extra_router: None,
598            public_url: None,
599            log_request_headers: false,
600            compression_enabled: false,
601            compression_min_size: 1024,
602            max_concurrent_requests: None,
603            admin_enabled: false,
604            admin_role: "admin".to_owned(),
605            #[cfg(feature = "metrics")]
606            metrics_enabled: false,
607            #[cfg(feature = "metrics")]
608            metrics_bind: "127.0.0.1:9090".into(),
609            security_headers: SecurityHeadersConfig::default(),
610            tls_handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
611            max_concurrent_tls_handshakes: DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES,
612            extra_route_rate_limit: None,
613            tool_rate_limit_burst: None,
614            extra_route_rate_limit_burst: None,
615        }
616    }
617
618    // ---------------------------------------------------------------
619    // Builder methods (fluent, consume + return self).
620    //
621    // Each method is `#[must_use]` because dropping the returned
622    // `McpServerConfig` discards the configuration change.
623    // ---------------------------------------------------------------
624
625    /// Attach an authentication configuration. Required for
626    /// [`enable_admin`](Self::enable_admin) and any non-public deployment.
627    #[must_use]
628    pub fn with_auth(mut self, auth: AuthConfig) -> Self {
629        self.auth = Some(auth);
630        self
631    }
632
633    /// Override one or more of the OWASP security headers emitted on
634    /// every response. See [`SecurityHeadersConfig`] for the three-state
635    /// semantic (`None` = default, `Some("")` = omit, `Some(v)` =
636    /// override). Values are validated by [`Self::validate`].
637    #[must_use]
638    pub fn with_security_headers(mut self, headers: SecurityHeadersConfig) -> Self {
639        self.security_headers = headers;
640        self
641    }
642
643    /// Override the bind address (e.g. `127.0.0.1:8080`). Useful when the
644    /// final port is only known after pre-binding an ephemeral listener
645    /// (tests, dynamic-port deployments).
646    #[must_use]
647    pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
648        self.bind_addr = addr.into();
649        self
650    }
651
652    /// Attach an RBAC policy. Tool calls are checked against the policy
653    /// after authentication.
654    #[must_use]
655    pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
656        self.rbac = Some(rbac);
657        self
658    }
659
660    /// Configure TLS by providing the certificate and private key paths
661    /// (PEM). Both must be readable at startup. Without this call, the
662    /// server runs plain HTTP.
663    #[must_use]
664    pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
665        self.tls_cert_path = Some(cert_path.into());
666        self.tls_key_path = Some(key_path.into());
667        self
668    }
669
670    /// Set the externally reachable base URL (e.g. `https://mcp.example.com`).
671    /// Required when binding `0.0.0.0` behind a reverse proxy or inside
672    /// a container so OAuth metadata and auto-derived origins resolve correctly.
673    #[must_use]
674    pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
675        self.public_url = Some(url.into());
676        self
677    }
678
679    /// Replace the allowed Origin allow-list (DNS-rebinding protection).
680    /// When empty and [`with_public_url`](Self::with_public_url) is set,
681    /// the origin is auto-derived.
682    #[must_use]
683    pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
684    where
685        I: IntoIterator<Item = S>,
686        S: Into<String>,
687    {
688        self.allowed_origins = origins.into_iter().map(Into::into).collect();
689        self
690    }
691
692    /// Merge an additional axum router at the top level. Routes added
693    /// here **bypass** rmcp-server-kit auth and RBAC; the application is responsible
694    /// for its own protection.
695    ///
696    /// To support that protection (e.g. per-IP rate limiting on
697    /// unauthenticated endpoints), every request served by [`serve`]
698    /// carries the client peer address regardless of whether TLS is
699    /// enabled: extract the framework-owned [`PeerAddr`] in your
700    /// handlers, or rely on [`axum::extract::ConnectInfo`]`<SocketAddr>`
701    /// for stock third-party middleware (e.g. per-IP rate-limit key
702    /// extractors). Neither extension exists under [`serve_stdio`],
703    /// which has no network peer.
704    #[must_use]
705    pub fn with_extra_router(mut self, router: axum::Router) -> Self {
706        self.extra_router = Some(router);
707        self
708    }
709
710    /// Install an async readiness probe for `/readyz`. Without this call,
711    /// `/readyz` mirrors `/healthz` (always 200 OK).
712    #[must_use]
713    pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
714        self.readiness_check = Some(check);
715        self
716    }
717
718    /// Override the maximum request body (bytes). Must be `> 0`.
719    /// Default: 1 MiB.
720    #[must_use]
721    pub fn with_max_request_body(mut self, bytes: usize) -> Self {
722        self.max_request_body = bytes;
723        self
724    }
725
726    /// Override the per-request processing timeout. Default: 2 minutes.
727    #[must_use]
728    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
729        self.request_timeout = timeout;
730        self
731    }
732
733    /// Override the graceful shutdown grace period. Default: 30 seconds.
734    #[must_use]
735    pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
736        self.shutdown_timeout = timeout;
737        self
738    }
739
740    /// Override the MCP session idle timeout. Default: 20 minutes.
741    #[must_use]
742    pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
743        self.session_idle_timeout = timeout;
744        self
745    }
746
747    /// Override the SSE keep-alive interval. Default: 15 seconds.
748    #[must_use]
749    pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
750        self.sse_keep_alive = interval;
751        self
752    }
753
754    /// Cap the global number of in-flight HTTP requests via
755    /// `tower::load_shed`. Excess requests receive 503 Service Unavailable.
756    /// Default: unlimited.
757    #[must_use]
758    pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
759        self.max_concurrent_requests = Some(limit);
760        self
761    }
762
763    /// Override the per-handshake deadline on the TLS accept path.
764    /// Idle or slow-loris connections are dropped once it elapses.
765    /// Default: 10s. Must be greater than zero.
766    ///
767    /// Startup-only: the value is bound at listener construction and is
768    /// NOT hot-reloadable via [`ReloadHandle`]. Has no effect unless TLS
769    /// is configured via [`Self::with_tls`].
770    #[must_use]
771    pub fn with_tls_handshake_timeout(mut self, timeout: Duration) -> Self {
772        self.tls_handshake_timeout = timeout;
773        self
774    }
775
776    /// Override the cap on concurrently in-flight TLS handshakes. At
777    /// saturation the acceptor stops pulling new connections from the
778    /// kernel backlog (backpressure) instead of accepting and dropping.
779    /// Default: 256. Must be greater than zero.
780    ///
781    /// Startup-only: the value is bound at listener construction and is
782    /// NOT hot-reloadable via [`ReloadHandle`]. Has no effect unless TLS
783    /// is configured via [`Self::with_tls`].
784    #[must_use]
785    pub fn with_max_concurrent_tls_handshakes(mut self, limit: usize) -> Self {
786        self.max_concurrent_tls_handshakes = limit;
787        self
788    }
789
790    /// Cap tool invocations per source IP per minute. Enforced on every
791    /// `tools/call` request.
792    #[must_use]
793    pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
794        self.tool_rate_limit = Some(per_minute);
795        self
796    }
797
798    /// Cap requests per source IP per minute on routes merged via
799    /// [`with_extra_router`](Self::with_extra_router) — the natural
800    /// protection for unauthenticated application endpoints (OAuth
801    /// callbacks, registration, …) that bypass auth/RBAC by design.
802    ///
803    /// Must be greater than zero (validated by
804    /// [`validate`](Self::validate)). Startup-only. See the
805    /// `extra_route_rate_limit` field docs for keying semantics and
806    /// caveats (direct peer only, IPv6 rotation, proxy collapse,
807    /// bounded-memory shared-fate under key spray).
808    #[must_use]
809    pub fn with_extra_route_rate_limit(mut self, per_minute: u32) -> Self {
810        self.extra_route_rate_limit = Some(per_minute);
811        self
812    }
813
814    /// Set the burst capacity for the tool rate limiter (bucket size;
815    /// the sustained rate stays [`with_tool_rate_limit`](Self::with_tool_rate_limit)).
816    /// Requires the tool rate limit to be set; must be greater than zero
817    /// (both validated by [`validate`](Self::validate)).
818    #[must_use]
819    pub fn with_tool_rate_limit_burst(mut self, burst: u32) -> Self {
820        self.tool_rate_limit_burst = Some(burst);
821        self
822    }
823
824    /// Set the burst capacity for the extra-route rate limiter (bucket
825    /// size; the sustained rate stays
826    /// [`with_extra_route_rate_limit`](Self::with_extra_route_rate_limit)).
827    /// Requires the extra-route rate limit to be set; must be greater
828    /// than zero (both validated by [`validate`](Self::validate)).
829    #[must_use]
830    pub fn with_extra_route_rate_limit_burst(mut self, burst: u32) -> Self {
831        self.extra_route_rate_limit_burst = Some(burst);
832        self
833    }
834
835    /// Register a callback that receives the [`ReloadHandle`] after the
836    /// server is built. Use it to wire SIGHUP-style hot reloads of API
837    /// keys and RBAC policy.
838    #[must_use]
839    pub fn with_reload_callback<F>(mut self, callback: F) -> Self
840    where
841        F: FnOnce(ReloadHandle) + Send + 'static,
842    {
843        self.on_reload_ready = Some(Box::new(callback));
844        self
845    }
846
847    /// Enable gzip/brotli response compression on MCP responses.
848    /// `min_size` is the smallest body size (bytes) eligible for
849    /// compression. Default min size: 1024.
850    #[must_use]
851    pub fn enable_compression(mut self, min_size: u16) -> Self {
852        self.compression_enabled = true;
853        self.compression_min_size = min_size;
854        self
855    }
856
857    /// Enable `/admin/*` diagnostic endpoints. Requires
858    /// [`with_auth`](Self::with_auth) to be set and enabled; otherwise
859    /// [`validate`](Self::validate) returns an error. `role` is the RBAC
860    /// role gate (default: `"admin"`).
861    #[must_use]
862    pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
863        self.admin_enabled = true;
864        self.admin_role = role.into();
865        self
866    }
867
868    /// Log inbound HTTP request headers at DEBUG level. Sensitive
869    /// values remain redacted by the logging layer.
870    #[must_use]
871    pub fn enable_request_header_logging(mut self) -> Self {
872        self.log_request_headers = true;
873        self
874    }
875
/// Turn on the Prometheus metrics listener at address `bind` (e.g.
/// `127.0.0.1:9090`). Requires the `metrics` crate feature.
#[cfg(feature = "metrics")]
#[must_use]
pub fn with_metrics(self, bind: impl Into<String>) -> Self {
    let mut cfg = self;
    cfg.metrics_enabled = true;
    cfg.metrics_bind = bind.into();
    cfg
}
885
886    /// Validate the configuration and consume `self`, returning a
887    /// [`Validated<McpServerConfig>`] proof token required by [`serve`]
888    /// and [`serve_with_listener`]. This is the only way to construct
889    /// `Validated<McpServerConfig>`, so the type system guarantees
890    /// validation has run before the server starts.
891    ///
892    /// Checks:
893    ///
894    /// 1. `admin_enabled` requires `auth` to be configured and enabled.
895    /// 2. `tls_cert_path` and `tls_key_path` must both be set or both
896    ///    be unset.
897    /// 3. `bind_addr` must parse as a [`SocketAddr`].
898    /// 4. `public_url`, when set, must start with `http://` or `https://`.
899    /// 5. Each entry in `allowed_origins` must start with `http://` or
900    ///    `https://`.
901    /// 6. `max_request_body` must be greater than zero.
902    /// 7. When the `oauth` feature is enabled and an [`OAuthConfig`] is
903    ///    present, all OAuth URL fields (`jwks_uri`, `proxy.authorize_url`,
904    ///    `proxy.token_url`, `proxy.introspection_url`,
905    ///    `proxy.revocation_url`, `token_exchange.token_url`) must parse
906    ///    and use the `https` scheme. Set
907    ///    [`OAuthConfig::allow_http_oauth_urls`] to permit `http://`
908    ///    targets (strongly discouraged in production - see the field-level
909    ///    docs for the threat model).
910    ///
911    /// [`OAuthConfig`]: crate::oauth::OAuthConfig
912    /// [`OAuthConfig::allow_http_oauth_urls`]: crate::oauth::OAuthConfig::allow_http_oauth_urls
913    ///
914    /// # Errors
915    ///
916    /// Returns [`McpxError::Config`] with a human-readable message on
917    /// the first validation failure.
918    pub fn validate(self) -> Result<Validated<Self>, McpxError> {
919        self.check()?;
920        Ok(Validated(self))
921    }
922
923    /// Validate the burst knobs: every burst must be greater than zero
924    /// when set, and the two top-level bursts require their base limiter
925    /// to be configured. The auth bursts
926    /// (`RateLimitConfig::{burst, pre_auth_burst}`) have no orphan rule:
927    /// their base rates always resolve (`max_attempts_per_minute` is
928    /// mandatory; the pre-auth base derives from it when unset).
929    fn check_burst_knobs(&self) -> Result<(), McpxError> {
930        if self.tool_rate_limit_burst == Some(0) {
931            return Err(McpxError::Config(
932                "tool_rate_limit_burst must be greater than zero".into(),
933            ));
934        }
935        if self.extra_route_rate_limit_burst == Some(0) {
936            return Err(McpxError::Config(
937                "extra_route_rate_limit_burst must be greater than zero".into(),
938            ));
939        }
940        if self.tool_rate_limit_burst.is_some() && self.tool_rate_limit.is_none() {
941            return Err(McpxError::Config(
942                "tool_rate_limit_burst requires tool_rate_limit to be set".into(),
943            ));
944        }
945        if self.extra_route_rate_limit_burst.is_some() && self.extra_route_rate_limit.is_none() {
946            return Err(McpxError::Config(
947                "extra_route_rate_limit_burst requires extra_route_rate_limit to be set".into(),
948            ));
949        }
950        if let Some(rl) = self.auth.as_ref().and_then(|a| a.rate_limit.as_ref()) {
951            if rl.burst == Some(0) {
952                return Err(McpxError::Config(
953                    "auth rate_limit.burst must be greater than zero".into(),
954                ));
955            }
956            if rl.pre_auth_burst == Some(0) {
957                return Err(McpxError::Config(
958                    "auth rate_limit.pre_auth_burst must be greater than zero".into(),
959                ));
960            }
961        }
962        Ok(())
963    }
964
965    /// Run the validation checks without consuming `self`. Used by
966    /// internal call sites (e.g. tests) that need to inspect a config
967    /// without taking ownership.
968    fn check(&self) -> Result<(), McpxError> {
969        // 1. admin <-> auth dependency. Mirrors the runtime check in
970        //    `build_app_router`: admin endpoints require an auth state,
971        //    which is built only when `auth` is `Some` *and* `enabled`.
972        if self.admin_enabled {
973            let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
974            if !auth_enabled {
975                return Err(McpxError::Config(
976                    "admin_enabled=true requires auth to be configured and enabled".into(),
977                ));
978            }
979        }
980
981        // 2. TLS cert / key must be paired
982        match (&self.tls_cert_path, &self.tls_key_path) {
983            (Some(_), None) => {
984                return Err(McpxError::Config(
985                    "tls_cert_path is set but tls_key_path is missing".into(),
986                ));
987            }
988            (None, Some(_)) => {
989                return Err(McpxError::Config(
990                    "tls_key_path is set but tls_cert_path is missing".into(),
991                ));
992            }
993            _ => {}
994        }
995
996        // 3. bind_addr parses
997        if self.bind_addr.parse::<SocketAddr>().is_err() {
998            return Err(McpxError::Config(format!(
999                "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
1000                self.bind_addr
1001            )));
1002        }
1003
1004        // 4. public_url scheme
1005        if let Some(ref url) = self.public_url
1006            && !(url.starts_with("http://") || url.starts_with("https://"))
1007        {
1008            return Err(McpxError::Config(format!(
1009                "public_url {url:?} must start with http:// or https://"
1010            )));
1011        }
1012
1013        // 5. allowed_origins scheme
1014        for origin in &self.allowed_origins {
1015            if !(origin.starts_with("http://") || origin.starts_with("https://")) {
1016                return Err(McpxError::Config(format!(
1017                    "allowed_origins entry {origin:?} must start with http:// or https://"
1018                )));
1019            }
1020        }
1021
1022        // 6. max_request_body > 0
1023        if self.max_request_body == 0 {
1024            return Err(McpxError::Config(
1025                "max_request_body must be greater than zero".into(),
1026            ));
1027        }
1028
1029        // 6b. extra_route_rate_limit, when set, must be > 0. Unlike the
1030        // legacy tool_rate_limit (which clamps 0 to its default at
1031        // construction), new knobs fail fast on nonsensical values.
1032        if self.extra_route_rate_limit == Some(0) {
1033            return Err(McpxError::Config(
1034                "extra_route_rate_limit must be greater than zero".into(),
1035            ));
1036        }
1037
1038        // 6c. Burst knobs (extracted helper).
1039        self.check_burst_knobs()?;
1040
1041        // 7. OAuth URL fields enforce HTTPS (unless `allow_http_oauth_urls`)
1042        #[cfg(feature = "oauth")]
1043        if let Some(auth_cfg) = &self.auth
1044            && let Some(oauth_cfg) = &auth_cfg.oauth
1045        {
1046            oauth_cfg.validate()?;
1047        }
1048
1049        // 8. Security-header overrides parse as valid HTTP header values,
1050        //    and HSTS does not smuggle in a `preload` directive.
1051        validate_security_headers(&self.security_headers)?;
1052
1053        // 9. max_concurrent_requests must be > 0 when set. Zero would
1054        //    deadlock the global concurrency limiter and reject every
1055        //    request. Mirrors the TOML-side check in `src/config.rs`.
1056        if let Some(0) = self.max_concurrent_requests {
1057            return Err(McpxError::Config(
1058                "max_concurrent_requests must be greater than zero when set".into(),
1059            ));
1060        }
1061
1062        // 10. Auth rate-limit `max_tracked_keys` must be > 0. A zero cap
1063        //     would force `BoundedKeyedLimiter` to evict on every insert
1064        //     and effectively disable rate limiting.
1065        if let Some(auth_cfg) = &self.auth
1066            && let Some(rl) = &auth_cfg.rate_limit
1067            && rl.max_tracked_keys == 0
1068        {
1069            return Err(McpxError::Config(
1070                "auth.rate_limit.max_tracked_keys must be greater than zero".into(),
1071            ));
1072        }
1073
1074        // 11. tls_handshake_timeout must be > 0. A zero deadline would
1075        //     reap every handshake before it could complete, rejecting
1076        //     all TLS connections. Mirrors the TOML-side check in
1077        //     `src/config.rs`.
1078        if self.tls_handshake_timeout == Duration::ZERO {
1079            return Err(McpxError::Config(
1080                "tls_handshake_timeout must be greater than zero".into(),
1081            ));
1082        }
1083
1084        // 12. max_concurrent_tls_handshakes must be > 0. A zero-permit
1085        //     semaphore would never admit a handshake, deadlocking the
1086        //     TLS accept path. Mirrors the TOML-side check in
1087        //     `src/config.rs`.
1088        if self.max_concurrent_tls_handshakes == 0 {
1089            return Err(McpxError::Config(
1090                "max_concurrent_tls_handshakes must be greater than zero".into(),
1091            ));
1092        }
1093
1094        Ok(())
1095    }
1096}
1097
1098/// Handle for hot-reloading server configuration without restart.
1099///
1100/// Obtained via [`McpServerConfig::on_reload_ready`].
1101/// All swap operations are lock-free and wait-free -- in-flight requests
1102/// finish with the old values while new requests see the update immediately.
1103#[allow(
1104    missing_debug_implementations,
1105    reason = "contains Arc<AuthState> with non-Debug fields"
1106)]
1107pub struct ReloadHandle {
1108    auth: Option<Arc<AuthState>>,
1109    rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
1110    crl_set: Option<Arc<CrlSet>>,
1111}
1112
1113impl ReloadHandle {
1114    /// Atomically replace the API key list used by the auth middleware.
1115    pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
1116        if let Some(ref auth) = self.auth {
1117            auth.reload_keys(keys);
1118        }
1119    }
1120
1121    /// Atomically replace the RBAC policy used by the RBAC middleware.
1122    pub fn reload_rbac(&self, policy: RbacPolicy) {
1123        if let Some(ref rbac) = self.rbac {
1124            rbac.store(Arc::new(policy));
1125            tracing::info!("RBAC policy reloaded");
1126        }
1127    }
1128
1129    /// Force an immediate refresh of all cached mTLS CRLs.
1130    ///
1131    /// # Errors
1132    ///
1133    /// Returns an error if CRL refresh is unavailable or verifier rebuild fails.
1134    pub async fn refresh_crls(&self) -> Result<(), McpxError> {
1135        let Some(ref crl_set) = self.crl_set else {
1136            return Err(McpxError::Config(
1137                "CRL refresh requested but mTLS CRL support is not configured".into(),
1138            ));
1139        };
1140
1141        crl_set.force_refresh().await
1142    }
1143}
1144
1145/// Generic MCP HTTP server.
1146///
1147/// Wraps an axum server with `/healthz` and `/mcp` endpoints.
1148/// When `tls_cert_path` and `tls_key_path` are both set, the server binds
1149/// with TLS (rustls). Optionally supports mTLS client certificate auth.
1150///
1151/// # Errors
1152///
1153/// Returns an error if the TCP listener cannot bind, TLS config is invalid,
1154/// or the server fails.
1155// NOTE: cognitive complexity reduced from 111/25 to 83/25 by
1156// extracting `run_server` (serve-loop tail) and `install_oauth_proxy_routes`.
1157// Remaining flow is a linear router builder: middleware layering, feature-
1158// gated auth/RBAC wiring, and PRM/metrics installation. Further extraction
1159// would require threading many `&mut Router` helpers and hurt readability
1160// of the layer order (which is security-relevant and must stay visible).
1161#[allow(
1162    clippy::too_many_lines,
1163    clippy::cognitive_complexity,
1164    reason = "middleware layer order is security-critical and must remain visible at one glance; extracting `&mut Router` helpers would obscure the auth/RBAC/origin/rate-limit ordering"
1165)]
1166/// Internal bundle of values produced by [`build_app_router`] and
1167/// consumed by [`serve`] / [`serve_with_listener`] when driving the
1168/// HTTP listener.
1169struct AppRunParams {
1170    /// TLS cert/key paths when TLS is configured.
1171    tls_paths: Option<(PathBuf, PathBuf)>,
1172    /// Per-handshake deadline on the TLS accept path.
1173    tls_handshake_timeout: Duration,
1174    /// Cap on concurrently in-flight TLS handshakes.
1175    max_concurrent_tls_handshakes: usize,
1176    /// mTLS configuration when mutual-TLS auth is enabled.
1177    mtls_config: Option<MtlsConfig>,
1178    /// Graceful shutdown drain window.
1179    shutdown_timeout: Duration,
1180    /// Shared auth state used by hot-reload callbacks.
1181    auth_state: Option<Arc<AuthState>>,
1182    /// Hot-reloadable RBAC state used by reload callbacks.
1183    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
1184    /// Optional callback that receives the final [`ReloadHandle`].
1185    on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
1186    /// Server-internal cancellation token. Cancelled by [`run_server`]
1187    /// once the shutdown trigger fires (so rmcp's child token also
1188    /// fires, terminating in-flight MCP sessions).
1189    ct: CancellationToken,
1190    /// `"http"` or `"https"` -- used only for boot-time logging.
1191    scheme: &'static str,
1192    /// Server name -- used only for boot-time logging.
1193    name: String,
1194}
1195
1196/// Build the full application axum [`axum::Router`] (MCP route +
1197/// middleware stack + admin + OAuth + health endpoints + security
1198/// headers + CORS + compression + concurrency limit + origin check)
1199/// and the [`AppRunParams`] needed to drive it.
1200///
1201/// This is the shared core of [`serve`] and [`serve_with_listener`].
1202/// It performs *no* network I/O: callers are responsible for binding
1203/// (or accepting a pre-bound) [`TcpListener`] and invoking
1204/// [`run_server`].
1205#[allow(
1206    clippy::cognitive_complexity,
1207    reason = "router assembly is intrinsically sequential; splitting harms readability"
1208)]
1209#[allow(
1210    deprecated,
1211    reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
1212)]
1213fn build_app_router<H, F>(
1214    mut config: McpServerConfig,
1215    handler_factory: F,
1216) -> anyhow::Result<(axum::Router, AppRunParams)>
1217where
1218    H: ServerHandler + 'static,
1219    F: Fn() -> H + Send + Sync + Clone + 'static,
1220{
1221    let ct = CancellationToken::new();
1222
1223    let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
1224    tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");
1225
1226    let mcp_service = StreamableHttpService::new(
1227        move || Ok(handler_factory()),
1228        {
1229            let mut mgr = LocalSessionManager::default();
1230            mgr.session_config.keep_alive = Some(config.session_idle_timeout);
1231            mgr.into()
1232        },
1233        StreamableHttpServerConfig::default()
1234            .with_allowed_hosts(allowed_hosts)
1235            .with_sse_keep_alive(Some(config.sse_keep_alive))
1236            .with_cancellation_token(ct.child_token()),
1237    );
1238
1239    // Build the MCP route, optionally wrapped with auth and RBAC middleware.
1240    let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);
1241
1242    // Build auth state eagerly when auth is configured so we can wire both
1243    // the auth middleware *and* the optional admin router against the same
1244    // state. The middleware itself is installed further down in layer order.
1245    let auth_state: Option<Arc<AuthState>> = match config.auth {
1246        Some(ref auth_config) if auth_config.enabled => {
1247            let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
1248            let pre_auth_limiter = auth_config
1249                .rate_limit
1250                .as_ref()
1251                .map(crate::auth::build_pre_auth_limiter);
1252
1253            #[cfg(feature = "oauth")]
1254            let jwks_cache = auth_config
1255                .oauth
1256                .as_ref()
1257                .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
1258                .transpose()
1259                .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;
1260
1261            Some(Arc::new(AuthState {
1262                api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
1263                rate_limiter,
1264                pre_auth_limiter,
1265                #[cfg(feature = "oauth")]
1266                jwks_cache,
1267                seen_identities: crate::auth::SeenIdentitySet::new(),
1268                counters: crate::auth::AuthCounters::default(),
1269            }))
1270        }
1271        _ => None,
1272    };
1273
1274    // Build the RBAC policy swap early so the admin router and the later
1275    // RBAC middleware layer share the same hot-reloadable state.
1276    let rbac_swap = Arc::new(ArcSwap::new(
1277        config
1278            .rbac
1279            .clone()
1280            .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
1281    ));
1282
1283    // Optional /admin/* diagnostic routes. Merged BEFORE the
1284    // body-limit/timeout/RBAC/origin/auth layers so all of them apply.
1285    if config.admin_enabled {
1286        let Some(ref auth_state_ref) = auth_state else {
1287            return Err(anyhow::anyhow!(
1288                "admin_enabled=true requires auth to be configured and enabled"
1289            ));
1290        };
1291        let admin_state = crate::admin::AdminState {
1292            started_at: std::time::Instant::now(),
1293            name: config.name.clone(),
1294            version: config.version.clone(),
1295            auth: Some(Arc::clone(auth_state_ref)),
1296            rbac: Arc::clone(&rbac_swap),
1297        };
1298        let admin_cfg = crate::admin::AdminConfig {
1299            role: config.admin_role.clone(),
1300        };
1301        mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
1302        tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
1303    }
1304
1305    // ----- Middleware order (CRITICAL: read carefully) ------------------
1306    //
1307    // axum/tower applies layers **bottom-up** at runtime: the LAST layer
1308    // added is the OUTERMOST (runs first on a request). To achieve a
1309    // request-time flow of:
1310    //
1311    //   body-limit -> timeout -> auth -> rbac -> handler
1312    //
1313    // we add layers in the REVERSE order:
1314    //
1315    //   1. RBAC               (innermost, runs last before handler)
1316    //   2. auth               (parses identity, sets extension for RBAC)
1317    //   3. timeout            (bounds total request time)
1318    //   4. body-limit         (outermost on /mcp; caps payload before
1319    //                          anything else reads/buffers it)
1320    //
1321    // Origin validation is installed on the OUTER router (after the
1322    // /mcp router is merged in), so it also protects /healthz, /readyz,
1323    // /version, and any OAuth proxy endpoints.
1324    //
1325    // Rationale:
1326    // - Body-limit must be outermost on /mcp so RBAC (which reads the
1327    //   JSON-RPC body) cannot be DoS'd by a 100MB payload.
1328    // - Auth must run before RBAC because RBAC consumes
1329    //   `req.extensions().get::<AuthIdentity>()` to enforce per-role
1330    //   policy.
1331    // - Origin runs before auth so we reject cross-origin requests
1332    //   without spending Argon2 cycles on unauthenticated callers.
1333
1334    // [1] RBAC + tool rate-limit layer (innermost; closest to handler).
1335    // Always installed: even when RBAC is disabled, tool rate limiting may
1336    // be active (MCP spec: servers MUST rate limit tool invocations).
1337    {
1338        let tool_limiter: Option<Arc<ToolRateLimiter>> = config
1339            .tool_rate_limit
1340            .map(|per_minute| build_tool_rate_limiter(per_minute, config.tool_rate_limit_burst));
1341
1342        if rbac_swap.load().is_enabled() {
1343            tracing::info!("RBAC enforcement enabled on /mcp");
1344        }
1345        if let Some(limit) = config.tool_rate_limit {
1346            tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
1347        }
1348
1349        let rbac_for_mw = Arc::clone(&rbac_swap);
1350        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1351            let p = rbac_for_mw.load_full();
1352            let tl = tool_limiter.clone();
1353            rbac_middleware(p, tl, req, next)
1354        }));
1355    }
1356
1357    // [2] Auth layer (runs before RBAC so AuthIdentity is in extensions).
1358    if let Some(ref auth_config) = config.auth
1359        && auth_config.enabled
1360    {
1361        let Some(ref state) = auth_state else {
1362            return Err(anyhow::anyhow!("auth state missing despite enabled config"));
1363        };
1364
1365        let methods: Vec<&str> = [
1366            auth_config.mtls.is_some().then_some("mTLS"),
1367            (!auth_config.api_keys.is_empty()).then_some("bearer"),
1368            #[cfg(feature = "oauth")]
1369            auth_config.oauth.is_some().then_some("oauth-jwt"),
1370        ]
1371        .into_iter()
1372        .flatten()
1373        .collect();
1374
1375        tracing::info!(
1376            methods = %methods.join(", "),
1377            api_keys = auth_config.api_keys.len(),
1378            "auth enabled on /mcp"
1379        );
1380
1381        let state_for_mw = Arc::clone(state);
1382        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1383            let s = Arc::clone(&state_for_mw);
1384            auth_middleware(s, req, next)
1385        }));
1386    }
1387
1388    // [3] Request timeout (returns 408 on expiry). Bounds total request
1389    // duration including auth + handler.
1390    mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
1391        axum::http::StatusCode::REQUEST_TIMEOUT,
1392        config.request_timeout,
1393    ));
1394
1395    // [4] Request body size limit (OUTERMOST on /mcp). Prevents OOM /
1396    // DoS from oversized payloads BEFORE any inner layer (auth, RBAC)
1397    // attempts to buffer or parse the body.
1398    mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
1399        config.max_request_body,
1400    ));
1401
1402    // Compute the effective allowed-origins list for the outer
1403    // origin-check layer (installed on the merged router below). When
1404    // `allowed_origins` is empty but `public_url` is set, auto-derive
1405    // the origin from the public URL so MCP clients (e.g. Claude Code)
1406    // that send `Origin: <server-url>` are accepted without explicit
1407    // config.
1408    let mut effective_origins = config.allowed_origins.clone();
1409    if effective_origins.is_empty()
1410        && let Some(ref url) = config.public_url
1411    {
1412        // Origin = scheme + "://" + host (+ ":" + port if non-default).
1413        // Strip any path/query from the public URL. Offsets come from
1414        // `find`, so they are char-boundary-aligned; `get(..)` keeps that
1415        // machine-checked (a violation degrades to an empty slice).
1416        if let Some(scheme_end) = url.find("://") {
1417            let scheme_with_sep = url.get(..scheme_end + 3).unwrap_or_default();
1418            let after_scheme = url.get(scheme_end + 3..).unwrap_or_default();
1419            let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
1420            let host = after_scheme.get(..host_end).unwrap_or_default();
1421            let origin = format!("{scheme_with_sep}{host}");
1422            tracing::info!(
1423                %origin,
1424                "auto-derived allowed origin from public_url"
1425            );
1426            effective_origins.push(origin);
1427        }
1428    }
1429    let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
1430    let cors_origins = Arc::clone(&allowed_origins);
1431    let log_request_headers = config.log_request_headers;
1432
1433    let readyz_route = if let Some(check) = config.readiness_check.take() {
1434        axum::routing::get(move || readyz(Arc::clone(&check)))
1435    } else {
1436        axum::routing::get(healthz)
1437    };
1438
1439    #[allow(unused_mut)] // mut needed when oauth feature adds PRM route
1440    let mut router = axum::Router::new()
1441        .route("/healthz", axum::routing::get(healthz))
1442        .route("/readyz", readyz_route)
1443        .route(
1444            "/version",
1445            axum::routing::get({
1446                // Pre-serialize the version payload once at router-build
1447                // time. The handler then serves a cheap `Arc::clone` of the
1448                // immutable bytes per request, avoiding `serde_json::Value`
1449                // allocation + serialization on every `/version` hit.
1450                let payload_bytes: Arc<[u8]> =
1451                    serialize_version_payload(&config.name, &config.version);
1452                move || {
1453                    let p = Arc::clone(&payload_bytes);
1454                    async move {
1455                        (
1456                            [(axum::http::header::CONTENT_TYPE, "application/json")],
1457                            p.to_vec(),
1458                        )
1459                    }
1460                }
1461            }),
1462        )
1463        .merge(mcp_router);
1464
1465    // Merge application-specific routes (bypass MCP auth/RBAC middleware).
1466    // When configured, wrap them — and only them — in the per-IP rate
1467    // limiter BEFORE merging: axum layers wrap only the routes already
1468    // present on the sub-router, so the limiter can never leak onto
1469    // `/mcp`, health, admin, or OAuth endpoints, while top-level layers
1470    // (origin check, peer-address normalization, ...) still run first.
1471    if let Some(extra) = config.extra_router.take() {
1472        let extra = match config.extra_route_rate_limit {
1473            Some(per_minute) => {
1474                let limiter =
1475                    build_extra_route_rate_limiter(per_minute, config.extra_route_rate_limit_burst);
1476                tracing::info!(per_minute, "extra-route per-IP rate limit enabled");
1477                extra.layer(axum::middleware::from_fn(move |req, next| {
1478                    let l = Arc::clone(&limiter);
1479                    extra_route_rate_limit_middleware(l, req, next)
1480                }))
1481            }
1482            None => extra,
1483        };
1484        router = router.merge(extra);
1485    }
1486
1487    // RFC 9728: Protected Resource Metadata endpoint.
1488    // When OAuth is configured, serve full metadata with authorization_servers.
1489    // Otherwise, serve a minimal document with just the resource URL and no
1490    // authorization_servers -- this tells MCP clients (e.g. Claude Code SDK)
1491    // that the server exists but does NOT require OAuth authentication,
1492    // preventing them from gating the connection behind a broken auth flow.
1493    let server_url = if let Some(ref url) = config.public_url {
1494        url.trim_end_matches('/').to_owned()
1495    } else {
1496        let prm_scheme = if config.tls_cert_path.is_some() {
1497            "https"
1498        } else {
1499            "http"
1500        };
1501        format!("{prm_scheme}://{}", config.bind_addr)
1502    };
1503    let resource_url = format!("{server_url}/mcp");
1504
1505    #[cfg(feature = "oauth")]
1506    let prm_metadata = if let Some(ref auth_config) = config.auth
1507        && let Some(ref oauth_config) = auth_config.oauth
1508    {
1509        crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
1510    } else {
1511        serde_json::json!({ "resource": resource_url })
1512    };
1513    #[cfg(not(feature = "oauth"))]
1514    let prm_metadata = serde_json::json!({ "resource": resource_url });
1515
1516    router = router.route(
1517        "/.well-known/oauth-protected-resource",
1518        axum::routing::get(move || {
1519            let m = prm_metadata.clone();
1520            async move { axum::Json(m) }
1521        }),
1522    );
1523
1524    // OAuth 2.1 proxy endpoints: when an OAuth proxy is configured, expose
1525    // /authorize, /token, /register, and authorization server metadata so
1526    // MCP clients can perform Authorization Code + PKCE against the upstream
1527    // IdP (e.g. Keycloak) transparently.
1528    #[cfg(feature = "oauth")]
1529    if let Some(ref auth_config) = config.auth
1530        && let Some(ref oauth_config) = auth_config.oauth
1531        && oauth_config.proxy.is_some()
1532    {
1533        router =
1534            install_oauth_proxy_routes(router, &server_url, oauth_config, auth_state.as_ref())?;
1535    }
1536
1537    // OWASP security response headers (applied to all responses).
1538    // HSTS is conditional on TLS being configured.
1539    let is_tls = config.tls_cert_path.is_some();
1540    let security_headers_cfg = Arc::new(config.security_headers.clone());
1541    router = router.layer(axum::middleware::from_fn(move |req, next| {
1542        let cfg = Arc::clone(&security_headers_cfg);
1543        security_headers_middleware(is_tls, cfg, req, next)
1544    }));
1545
1546    // CORS preflight layer (required for browser-based MCP clients).
1547    // Uses the same effective origins as the origin check middleware
1548    // (including auto-derived origin from public_url).
1549    if !cors_origins.is_empty() {
1550        let cors = tower_http::cors::CorsLayer::new()
1551            .allow_origin(
1552                cors_origins
1553                    .iter()
1554                    .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
1555                    .collect::<Vec<_>>(),
1556            )
1557            .allow_methods([
1558                axum::http::Method::GET,
1559                axum::http::Method::POST,
1560                axum::http::Method::OPTIONS,
1561            ])
1562            .allow_headers([
1563                axum::http::header::CONTENT_TYPE,
1564                axum::http::header::AUTHORIZATION,
1565            ]);
1566        router = router.layer(cors);
1567    }
1568
1569    // Optional response compression (gzip + brotli). Skips small bodies
1570    // to avoid overhead. Applied after CORS so preflight responses remain
1571    // uncompressed.
1572    if config.compression_enabled {
1573        use tower_http::compression::Predicate as _;
1574        let predicate = tower_http::compression::DefaultPredicate::new().and(
1575            tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
1576        );
1577        router = router.layer(
1578            tower_http::compression::CompressionLayer::new()
1579                .gzip(true)
1580                .br(true)
1581                .compress_when(predicate),
1582        );
1583        tracing::info!(
1584            min_size = config.compression_min_size,
1585            "response compression enabled (gzip, br)"
1586        );
1587    }
1588
1589    // Optional global concurrency cap. `load_shed` converts the
1590    // `ConcurrencyLimit` back-pressure error into 503 instead of hanging.
1591    if let Some(max) = config.max_concurrent_requests {
1592        let overload_handler = tower::ServiceBuilder::new()
1593            .layer(axum::error_handling::HandleErrorLayer::new(
1594                |_err: tower::BoxError| async {
1595                    (
1596                        axum::http::StatusCode::SERVICE_UNAVAILABLE,
1597                        axum::Json(serde_json::json!({
1598                            "error": "overloaded",
1599                            "error_description": "server is at capacity, retry later"
1600                        })),
1601                    )
1602                },
1603            ))
1604            .layer(tower::load_shed::LoadShedLayer::new())
1605            .layer(tower::limit::ConcurrencyLimitLayer::new(max));
1606        router = router.layer(overload_handler);
1607        tracing::info!(max, "global concurrency limit enabled");
1608    }
1609
1610    // JSON fallback for unmatched routes. Without this, axum returns
1611    // an empty-body 404 that breaks MCP clients (e.g. Claude Code SDK)
1612    // when they probe OAuth endpoints like /authorize or /token.
1613    router = router.fallback(|| async {
1614        (
1615            axum::http::StatusCode::NOT_FOUND,
1616            axum::Json(serde_json::json!({
1617                "error": "not_found",
1618                "error_description": "The requested endpoint does not exist"
1619            })),
1620        )
1621    });
1622
1623    // Prometheus metrics: recording middleware + separate listener.
1624    #[cfg(feature = "metrics")]
1625    if config.metrics_enabled {
1626        let metrics = Arc::new(
1627            crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
1628        );
1629        let m = Arc::clone(&metrics);
1630        router = router.layer(axum::middleware::from_fn(
1631            move |req: Request<Body>, next: Next| {
1632                let m = Arc::clone(&m);
1633                metrics_middleware(m, req, next)
1634            },
1635        ));
1636        let metrics_bind = config.metrics_bind.clone();
1637        let metrics_shutdown = ct.clone();
1638        tokio::spawn(async move {
1639            if let Err(e) =
1640                crate::metrics::serve_metrics(metrics_bind, metrics, metrics_shutdown).await
1641            {
1642                tracing::error!("metrics listener failed: {e}");
1643            }
1644        });
1645    }
1646
1647    // Peer-address normalization. Mirrors the TLS branch's peer address
1648    // into `ConnectInfo<SocketAddr>` and exposes the framework-owned
1649    // `PeerAddr` extension on both listener branches, so ALL routes on
1650    // the merged router (`/mcp`, `/healthz`, OAuth proxy endpoints,
1651    // admin endpoints, extra_router, ...) and all inner middleware see a
1652    // uniform peer-address contract regardless of TLS. Installed just
1653    // inside the origin check, which stays outermost by design.
1654    router = router.layer(axum::middleware::from_fn(normalize_peer_addr_middleware));
1655
1656    // Origin validation layer (MCP spec: servers MUST validate the
1657    // Origin header to prevent DNS rebinding attacks). Installed as the
1658    // OUTERMOST layer on the OUTER router so it protects ALL routes
1659    // (`/mcp`, `/healthz`, `/readyz`, `/version`, OAuth proxy endpoints,
1660    // admin endpoints, extra_router, etc.) and runs BEFORE auth so we
1661    // reject cross-origin attackers without spending Argon2 cycles.
1662    //
1663    // Origin-less requests (e.g. server-to-server probes, curl, native
1664    // MCP clients) are permitted; only requests with an Origin header
1665    // that does not match `effective_origins` are rejected.
1666    router = router.layer(axum::middleware::from_fn(move |req, next| {
1667        let origins = Arc::clone(&allowed_origins);
1668        origin_check_middleware(origins, log_request_headers, req, next)
1669    }));
1670
1671    let scheme = if config.tls_cert_path.is_some() {
1672        "https"
1673    } else {
1674        "http"
1675    };
1676
1677    let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
1678        (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
1679        _ => None,
1680    };
1681    let tls_handshake_timeout = config.tls_handshake_timeout;
1682    let max_concurrent_tls_handshakes = config.max_concurrent_tls_handshakes;
1683    let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();
1684
1685    Ok((
1686        router,
1687        AppRunParams {
1688            tls_paths,
1689            tls_handshake_timeout,
1690            max_concurrent_tls_handshakes,
1691            mtls_config,
1692            shutdown_timeout: config.shutdown_timeout,
1693            auth_state,
1694            rbac_swap,
1695            on_reload_ready: config.on_reload_ready.take(),
1696            ct,
1697            scheme,
1698            name: config.name.clone(),
1699        },
1700    ))
1701}
1702
1703/// Run the MCP HTTP server, binding to `config.bind_addr` and serving
1704/// until an OS shutdown signal (Ctrl-C / SIGTERM) is received.
1705///
1706/// This is the standard entry point for production deployments. For
1707/// deterministic shutdown control (e.g. integration tests), see
1708/// [`serve_with_listener`].
1709///
1710/// The configuration must be validated first via
1711/// [`McpServerConfig::validate`], which returns a [`Validated`] proof
1712/// token. This typestate guarantees, at compile time, that the server
1713/// never starts with an invalid configuration.
1714///
1715/// # Errors
1716///
1717/// Returns [`McpxError::Startup`] if binding to `config.bind_addr`
1718/// fails, or if the underlying axum server returns an error.
1719pub async fn serve<H, F>(
1720    config: Validated<McpServerConfig>,
1721    handler_factory: F,
1722) -> Result<(), McpxError>
1723where
1724    H: ServerHandler + 'static,
1725    F: Fn() -> H + Send + Sync + Clone + 'static,
1726{
1727    let config = config.into_inner();
1728    #[allow(
1729        deprecated,
1730        reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1731    )]
1732    let bind_addr = config.bind_addr.clone();
1733    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1734
1735    let listener = TcpListener::bind(&bind_addr)
1736        .await
1737        .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1738    log_listening(&params.name, params.scheme, &bind_addr);
1739
1740    run_server(
1741        router,
1742        listener,
1743        params.tls_paths,
1744        params.tls_handshake_timeout,
1745        params.max_concurrent_tls_handshakes,
1746        params.mtls_config,
1747        params.shutdown_timeout,
1748        params.auth_state,
1749        params.rbac_swap,
1750        params.on_reload_ready,
1751        params.ct,
1752    )
1753    .await
1754    .map_err(anyhow_to_startup)
1755}
1756
1757/// Run the MCP HTTP server on a pre-bound [`TcpListener`], with optional
1758/// readiness signalling and external shutdown control.
1759///
1760/// This variant is intended for **deterministic integration tests** and
1761/// for embedders that need to bind the listening socket themselves
1762/// (e.g. systemd socket activation). Compared to [`serve`]:
1763///
1764/// * The caller passes a `TcpListener` that is already bound. This
1765///   eliminates the bind race in tests that previously required
1766///   poll-the-`/healthz`-loop start-up detection.
1767/// * `ready_tx`, when `Some`, receives the socket's
1768///   [`SocketAddr`] *after* the router is built and immediately before
1769///   the server starts accepting connections. Tests can `await` the
1770///   matching `oneshot::Receiver` to know exactly when it is safe to
1771///   issue requests.
1772/// * `shutdown`, when `Some`, gives the caller a
1773///   [`CancellationToken`] that triggers the same graceful-shutdown
1774///   path as a real OS signal. This avoids cross-platform issues with
1775///   sending real `SIGTERM` from tests on Windows.
1776///
1777/// All three optional parameters degrade gracefully: if `ready_tx` is
1778/// `None`, no signal is sent; if `shutdown` is `None`, the server only
1779/// stops on an OS signal (just like [`serve`]).
1780///
1781/// # Errors
1782///
1783/// Returns [`McpxError::Startup`] if router construction fails, if reading
1784/// the listener's `local_addr()` fails, or if the underlying axum
1785/// server returns an error.
1786pub async fn serve_with_listener<H, F>(
1787    listener: TcpListener,
1788    config: Validated<McpServerConfig>,
1789    handler_factory: F,
1790    ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
1791    shutdown: Option<CancellationToken>,
1792) -> Result<(), McpxError>
1793where
1794    H: ServerHandler + 'static,
1795    F: Fn() -> H + Send + Sync + Clone + 'static,
1796{
1797    let config = config.into_inner();
1798    let local_addr = listener
1799        .local_addr()
1800        .map_err(|e| io_to_startup("listener.local_addr", e))?;
1801    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1802
1803    log_listening(&params.name, params.scheme, &local_addr.to_string());
1804
1805    // Forward external shutdown into the server-internal cancellation
1806    // token so `run_server`'s shutdown trigger picks it up alongside
1807    // any real OS signal.
1808    if let Some(external) = shutdown {
1809        let internal = params.ct.clone();
1810        tokio::spawn(async move {
1811            external.cancelled().await;
1812            internal.cancel();
1813        });
1814    }
1815
1816    // Signal readiness *after* the router is fully built and external
1817    // shutdown is wired, but *before* run_server takes ownership of
1818    // the listener. The receiver can immediately issue requests.
1819    if let Some(tx) = ready_tx {
1820        // Receiver may have been dropped (test gave up). That's fine.
1821        let _ = tx.send(local_addr);
1822    }
1823
1824    run_server(
1825        router,
1826        listener,
1827        params.tls_paths,
1828        params.tls_handshake_timeout,
1829        params.max_concurrent_tls_handshakes,
1830        params.mtls_config,
1831        params.shutdown_timeout,
1832        params.auth_state,
1833        params.rbac_swap,
1834        params.on_reload_ready,
1835        params.ct,
1836    )
1837    .await
1838    .map_err(anyhow_to_startup)
1839}
1840
1841/// Emit the standard "listening on …" log lines used by both
1842/// [`serve`] and [`serve_with_listener`].
1843#[allow(
1844    clippy::cognitive_complexity,
1845    reason = "tracing::info! macro expansions inflate the score; logic is trivial"
1846)]
1847fn log_listening(name: &str, scheme: &str, addr: &str) {
1848    tracing::info!("{name} listening on {addr}");
1849    tracing::info!("  MCP endpoint: {scheme}://{addr}/mcp");
1850    tracing::info!("  Health check: {scheme}://{addr}/healthz");
1851    tracing::info!("  Readiness:   {scheme}://{addr}/readyz");
1852}
1853
1854/// Drive the chosen axum server variant (TLS or plain) with a graceful
1855/// shutdown window. Consumes the router and listener.
1856///
1857/// # Shutdown semantics
1858///
1859/// A single shutdown trigger (the FIRST of: OS signal via
1860/// `shutdown_signal()`, or external cancellation of `ct`) starts BOTH:
1861///
1862/// 1. axum's `.with_graceful_shutdown(...)` future, which stops
1863///    accepting new connections and waits for in-flight requests to
1864///    drain;
1865/// 2. a `tokio::time::sleep(shutdown_timeout)` race that forces exit if
1866///    drainage exceeds `shutdown_timeout`.
1867///
1868/// Previously this function awaited `shutdown_signal()` independently
1869/// in BOTH branches of a `tokio::select!`. Because `shutdown_signal`
1870/// resolves once per future and consumes one signal, the force-exit
1871/// timer was tied to a SECOND signal (a second SIGTERM the operator
1872/// would never send). Under a single SIGTERM the graceful drain could
1873/// hang indefinitely. The current implementation derives both branches
1874/// from a single shared trigger so the timeout race is anchored to the
1875/// FIRST (and only) signal.
1876#[allow(
1877    clippy::too_many_arguments,
1878    clippy::cognitive_complexity,
1879    reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
1880)]
1881async fn run_server(
1882    router: axum::Router,
1883    listener: TcpListener,
1884    tls_paths: Option<(PathBuf, PathBuf)>,
1885    tls_handshake_timeout: Duration,
1886    max_concurrent_tls_handshakes: usize,
1887    mtls_config: Option<MtlsConfig>,
1888    shutdown_timeout: Duration,
1889    auth_state: Option<Arc<AuthState>>,
1890    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
1891    mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
1892    ct: CancellationToken,
1893) -> anyhow::Result<()> {
1894    // `shutdown_trigger` fires when the FIRST source resolves: either
1895    // an OS signal (Ctrl-C / SIGTERM) or external cancellation of `ct`
1896    // (which the test harness uses for deterministic shutdown).
1897    let shutdown_trigger = CancellationToken::new();
1898    {
1899        let trigger = shutdown_trigger.clone();
1900        let parent = ct.clone();
1901        tokio::spawn(async move {
1902            tokio::select! {
1903                () = shutdown_signal() => {}
1904                () = parent.cancelled() => {}
1905            }
1906            trigger.cancel();
1907        });
1908    }
1909
1910    let graceful = {
1911        let trigger = shutdown_trigger.clone();
1912        let ct = ct.clone();
1913        async move {
1914            trigger.cancelled().await;
1915            tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
1916            ct.cancel();
1917        }
1918    };
1919
1920    let force_exit_timer = {
1921        let trigger = shutdown_trigger.clone();
1922        async move {
1923            trigger.cancelled().await;
1924            tokio::time::sleep(shutdown_timeout).await;
1925        }
1926    };
1927
1928    if let Some((cert_path, key_path)) = tls_paths {
1929        let crl_set = if let Some(mtls) = mtls_config.as_ref()
1930            && mtls.crl_enabled
1931        {
1932            let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
1933            let (crl_set, discover_rx) =
1934                mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
1935                    .await
1936                    .map_err(|error| anyhow::anyhow!(error.to_string()))?;
1937            tokio::spawn(mtls_revocation::run_crl_refresher(
1938                Arc::clone(&crl_set),
1939                discover_rx,
1940                ct.clone(),
1941            ));
1942            Some(crl_set)
1943        } else {
1944            None
1945        };
1946
1947        if let Some(cb) = on_reload_ready.take() {
1948            cb(ReloadHandle {
1949                auth: auth_state.clone(),
1950                rbac: Some(Arc::clone(&rbac_swap)),
1951                crl_set: crl_set.clone(),
1952            });
1953        }
1954
1955        let tls_listener = TlsListener::new(
1956            listener,
1957            &cert_path,
1958            &key_path,
1959            mtls_config.as_ref(),
1960            crl_set,
1961            tls_handshake_timeout,
1962            max_concurrent_tls_handshakes,
1963        )?;
1964        let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
1965        tokio::select! {
1966            result = axum::serve(tls_listener, make_svc)
1967                .with_graceful_shutdown(graceful) => { result?; }
1968            () = force_exit_timer => {
1969                tracing::warn!("shutdown timeout exceeded, forcing exit");
1970            }
1971        }
1972    } else {
1973        if let Some(cb) = on_reload_ready.take() {
1974            cb(ReloadHandle {
1975                auth: auth_state,
1976                rbac: Some(rbac_swap),
1977                crl_set: None,
1978            });
1979        }
1980
1981        let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
1982        tokio::select! {
1983            result = axum::serve(listener, make_svc)
1984                .with_graceful_shutdown(graceful) => { result?; }
1985            () = force_exit_timer => {
1986                tracing::warn!("shutdown timeout exceeded, forcing exit");
1987            }
1988        }
1989    }
1990
1991    Ok(())
1992}
1993
/// Mount the OAuth 2.1 proxy surface on `router`: the authorization
/// server metadata document, `/authorize`, `/token`, `/register`, and,
/// when the proxy exposes admin endpoints and has an introspection or
/// revocation URL, the admin sub-router built by
/// [`build_oauth_admin_router`].
///
/// Callers are expected to pass a config whose `oauth_config.proxy` is
/// `Some`; when it is `None` the router is returned untouched.
///
/// # Errors
///
/// Returns [`McpxError::Startup`] if the shared
/// [`crate::oauth::OauthHttpClient`] cannot be initialized, and
/// propagates any error from building the admin sub-router.
#[cfg(feature = "oauth")]
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let Some(proxy) = oauth_config.proxy.as_ref() else {
        return Ok(router);
    };

    // One HTTP client shared by every proxy endpoint; each handler
    // captures its own clone.
    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;

    let asm = crate::oauth::authorization_server_metadata(server_url, oauth_config);

    let router = router
        .route(
            "/.well-known/oauth-authorization-server",
            axum::routing::get(move || {
                let doc = asm.clone();
                async move { axum::Json(doc) }
            }),
        )
        .route(
            "/authorize",
            axum::routing::get({
                let cfg = proxy.clone();
                move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
                    let cfg = cfg.clone();
                    async move { crate::oauth::handle_authorize(&cfg, &query.unwrap_or_default()) }
                }
            }),
        )
        .route(
            "/token",
            axum::routing::post({
                let cfg = proxy.clone();
                let client = http.clone();
                move |body: String| {
                    let cfg = cfg.clone();
                    let client = client.clone();
                    async move { crate::oauth::handle_token(&client, &cfg, &body).await }
                }
            })
            .layer(axum::middleware::from_fn(
                oauth_token_cache_headers_middleware,
            )),
        )
        .route(
            "/register",
            axum::routing::post({
                let cfg = proxy.clone();
                // The captured config is moved straight into the future,
                // without a further per-call clone inside the handler.
                move |axum::Json(body): axum::Json<serde_json::Value>| async move {
                    axum::Json(crate::oauth::handle_register(&cfg, &body))
                }
            })
            .layer(axum::middleware::from_fn(
                oauth_token_cache_headers_middleware,
            )),
        );

    let admin_exposed = proxy.expose_admin_endpoints;
    let has_introspection = proxy.introspection_url.is_some();
    let has_revocation = proxy.revocation_url.is_some();

    if admin_exposed
        && proxy.allow_unauthenticated_admin_endpoints
        && !proxy.require_auth_on_admin_endpoints
    {
        // M3 escape-hatch in effect: validate() accepted this only
        // because the operator explicitly opted in. Log it loudly at
        // startup so the choice is auditable.
        tracing::warn!(
            "OAuth introspect/revoke endpoints are unauthenticated by explicit \
             allow_unauthenticated_admin_endpoints opt-out; ensure an \
             authenticated reverse proxy fronts these routes"
        );
    }

    // Admin routes exist only when exposed AND at least one upstream URL
    // is configured; otherwise an empty router is merged.
    let admin_router = match admin_exposed && (has_introspection || has_revocation) {
        true => build_oauth_admin_router(proxy, http, auth_state)?,
        false => axum::Router::new(),
    };
    let router = router.merge(admin_router);

    tracing::info!(
        introspect = admin_exposed && has_introspection,
        revoke = admin_exposed && has_revocation,
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
2094
/// Assemble the optional admin sub-router: `/introspect` when an
/// introspection URL is configured and `/revoke` when a revocation URL
/// is configured.
///
/// Every route is wrapped in [`oauth_token_cache_headers_middleware`]
/// so RFC 6749 §5.1 / RFC 6750 §5.4 cache headers are emitted. When
/// `proxy.require_auth_on_admin_endpoints` is set, the auth middleware
/// is added as a further layer outside the cache-header layer.
///
/// # Errors
///
/// Returns [`McpxError::Startup`] when authentication is required on the
/// admin endpoints but no `auth_state` was supplied.
#[cfg(feature = "oauth")]
fn build_oauth_admin_router(
    proxy: &crate::oauth::OAuthProxyConfig,
    http: crate::oauth::OauthHttpClient,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let mut routes = axum::Router::new();

    if proxy.introspection_url.is_some() {
        routes = routes.route(
            "/introspect",
            axum::routing::post({
                let cfg = proxy.clone();
                let client = http.clone();
                move |body: String| {
                    let cfg = cfg.clone();
                    let client = client.clone();
                    async move { crate::oauth::handle_introspect(&client, &cfg, &body).await }
                }
            }),
        );
    }

    if proxy.revocation_url.is_some() {
        routes = routes.route(
            "/revoke",
            axum::routing::post({
                let cfg = proxy.clone();
                // Last user of the shared client: take it by move.
                let client = http;
                move |body: String| {
                    let cfg = cfg.clone();
                    let client = client.clone();
                    async move { crate::oauth::handle_revoke(&client, &cfg, &body).await }
                }
            }),
        );
    }

    let routes = routes.layer(axum::middleware::from_fn(
        oauth_token_cache_headers_middleware,
    ));

    // Without the auth requirement the cache-header layer is all we add.
    if !proxy.require_auth_on_admin_endpoints {
        return Ok(routes);
    }

    let state = auth_state.map(Arc::clone).ok_or_else(|| {
        McpxError::Startup("oauth proxy admin endpoints require auth state".into())
    })?;

    Ok(routes.layer(axum::middleware::from_fn(move |req, next| {
        auth_middleware(Arc::clone(&state), req, next)
    })))
}
2153
2154/// Build the host allow-list for rmcp's DNS rebinding protection.
2155///
2156/// Includes loopback hosts by default, then augments with host/authority
2157/// derived from `public_url` and the server bind address.
2158fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
2159    let mut hosts = vec![
2160        "localhost".to_owned(),
2161        "127.0.0.1".to_owned(),
2162        "::1".to_owned(),
2163    ];
2164
2165    if let Some(url) = public_url
2166        && let Ok(uri) = url.parse::<axum::http::Uri>()
2167        && let Some(authority) = uri.authority()
2168    {
2169        let host = authority.host().to_owned();
2170        if !hosts.iter().any(|h| h == &host) {
2171            hosts.push(host);
2172        }
2173
2174        let authority = authority.as_str().to_owned();
2175        if !hosts.iter().any(|h| h == &authority) {
2176            hosts.push(authority);
2177        }
2178    }
2179
2180    if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
2181        && let Some(authority) = uri.authority()
2182    {
2183        let host = authority.host().to_owned();
2184        if !hosts.iter().any(|h| h == &host) {
2185            hosts.push(host);
2186        }
2187
2188        let authority = authority.as_str().to_owned();
2189        if !hosts.iter().any(|h| h == &authority) {
2190            hosts.push(authority);
2191        }
2192    }
2193
2194    hosts
2195}
2196
2197// - TLS support -
2198
2199/// Implement axum's `Connected` trait for `TlsConnInfo` so that
2200/// `ConnectInfo<TlsConnInfo>` is available in middleware when serving
2201/// over our custom `TlsListener`.
2202///
2203/// The identity is read directly from the wrapping
2204/// [`AuthenticatedTlsStream`], which guarantees one-to-one correspondence
2205/// between the TLS connection and its mTLS identity. This eliminates the
2206/// previous shared-map approach which was vulnerable to ephemeral-port
2207/// reuse races (an unauthenticated reconnection from the same `(IP, port)`
2208/// pair could alias a stale entry).
2209impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
2210    for TlsConnInfo
2211{
2212    fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
2213        let addr = *target.remote_addr();
2214        let identity = target.io().identity().cloned();
2215        TlsConnInfo::new(addr, identity)
2216    }
2217}
2218
/// Default per-handshake deadline (10 seconds) on the TLS accept path.
/// Prevents idle or slow-loris connections from pinning handshake worker
/// tasks (and their semaphore permits) indefinitely.
///
/// Configurable since 1.9.0 via
/// [`McpServerConfig::with_tls_handshake_timeout`].
const DEFAULT_TLS_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
2226
/// Default upper bound (256) on concurrently in-flight TLS handshakes.
/// When saturated, the acceptor task stops pulling new connections from
/// the kernel backlog (backpressure) instead of accepting and dropping
/// them in user space.
///
/// Configurable since 1.9.0 via
/// [`McpServerConfig::with_max_concurrent_tls_handshakes`].
const DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES: usize = 256;
2235
/// Capacity (32 entries) of the completed-handshake queue between the
/// acceptor task and `axum::serve`'s `accept()` loop. Handshake workers
/// block on `send` when the queue is full, so a slow accept loop
/// back-pressures handshakes rather than buffering completed connections
/// unboundedly.
const TLS_ACCEPT_CHANNEL_CAPACITY: usize = 32;
2241
/// A TLS-wrapping listener that implements axum's `Listener` trait.
///
/// TCP accepts and TLS handshakes run on a dedicated background task: each
/// accepted connection's handshake is spawned onto its own worker task,
/// bounded by a configurable concurrent-handshake cap (default
/// [`DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES`]) and a per-handshake timeout
/// (default [`DEFAULT_TLS_HANDSHAKE_TIMEOUT`]). A slow or idle client
/// therefore cannot stall other connections behind a serialized inline
/// handshake.
///
/// When mTLS is configured, client certificates are verified against the
/// configured CA and the client identity is extracted at handshake time.
/// The extracted identity is bound to the connection itself via the
/// returned [`AuthenticatedTlsStream`], so it is impossible for an
/// unrelated connection to observe it.
struct TlsListener {
    /// Bound address, captured eagerly before the `TcpListener` moves into
    /// the acceptor task; returned as-is by `Listener::local_addr`.
    local_addr: SocketAddr,
    /// Receiving end of the bounded queue of completed handshakes produced
    /// by the acceptor task's workers, paired with the peer address.
    rx: mpsc::Receiver<(AuthenticatedTlsStream, SocketAddr)>,
    /// Background task driving TCP accepts and concurrent TLS handshakes.
    /// Aborted on drop so the listener releases its port deterministically.
    acceptor_task: tokio::task::JoinHandle<()>,
}
2267
2268impl TlsListener {
2269    fn new(
2270        inner: TcpListener,
2271        cert_path: &Path,
2272        key_path: &Path,
2273        mtls_config: Option<&MtlsConfig>,
2274        crl_set: Option<Arc<CrlSet>>,
2275        handshake_timeout: Duration,
2276        max_concurrent_handshakes: usize,
2277    ) -> anyhow::Result<Self> {
2278        // Install the ring crypto provider (ok to call multiple times).
2279        rustls::crypto::ring::default_provider()
2280            .install_default()
2281            .ok();
2282
2283        let certs = load_certs(cert_path)?;
2284        let key = load_key(key_path)?;
2285
2286        let mtls_default_role;
2287
2288        let tls_config = if let Some(mtls) = mtls_config {
2289            mtls_default_role = mtls.default_role.clone();
2290            let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
2291            {
2292                let Some(crl_set) = crl_set else {
2293                    return Err(anyhow::anyhow!(
2294                        "mTLS CRL verifier requested but CRL state was not initialized"
2295                    ));
2296                };
2297                Arc::new(DynamicClientCertVerifier::new(crl_set))
2298            } else {
2299                let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
2300                if mtls.required {
2301                    rustls::server::WebPkiClientVerifier::builder(root_store)
2302                        .build()
2303                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
2304                } else {
2305                    rustls::server::WebPkiClientVerifier::builder(root_store)
2306                        .allow_unauthenticated()
2307                        .build()
2308                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
2309                }
2310            };
2311
2312            tracing::info!(
2313                ca = %mtls.ca_cert_path.display(),
2314                required = mtls.required,
2315                crl_enabled = mtls.crl_enabled,
2316                "mTLS client auth configured"
2317            );
2318
2319            rustls::ServerConfig::builder_with_protocol_versions(&[
2320                &rustls::version::TLS12,
2321                &rustls::version::TLS13,
2322            ])
2323            .with_client_cert_verifier(verifier)
2324            .with_single_cert(certs, key)?
2325        } else {
2326            mtls_default_role = "viewer".to_owned();
2327            rustls::ServerConfig::builder_with_protocol_versions(&[
2328                &rustls::version::TLS12,
2329                &rustls::version::TLS13,
2330            ])
2331            .with_no_client_auth()
2332            .with_single_cert(certs, key)?
2333        };
2334
2335        let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
2336        tracing::info!(
2337            "TLS enabled (cert: {}, key: {})",
2338            cert_path.display(),
2339            key_path.display()
2340        );
2341        let local_addr = inner.local_addr()?;
2342        let (tx, rx) = mpsc::channel(TLS_ACCEPT_CHANNEL_CAPACITY);
2343        let acceptor_task = tokio::spawn(run_tls_acceptor(
2344            inner,
2345            acceptor,
2346            mtls_default_role,
2347            tx,
2348            handshake_timeout,
2349            max_concurrent_handshakes,
2350        ));
2351        Ok(Self {
2352            local_addr,
2353            rx,
2354            acceptor_task,
2355        })
2356    }
2357
2358    /// Extract the mTLS client cert identity from a completed TLS handshake.
2359    /// Returns `None` if no client certificate was presented or if the
2360    /// certificate could not be parsed into an [`AuthIdentity`].
2361    fn extract_handshake_identity(
2362        tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
2363        default_role: &str,
2364        addr: SocketAddr,
2365    ) -> Option<AuthIdentity> {
2366        let (_, server_conn) = tls_stream.get_ref();
2367        let cert_der = server_conn.peer_certificates()?.first()?;
2368        let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
2369        tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
2370        Some(id)
2371    }
2372}
2373
2374/// Drive TCP accepts and concurrent TLS handshakes for [`TlsListener`].
2375///
2376/// Each accepted connection's handshake runs on its own worker task under
2377/// a permit from a `max_concurrent_handshakes`-sized semaphore and a
2378/// `handshake_timeout` deadline. Completed handshakes are pushed to `tx`;
2379/// failures and timeouts are logged at DEBUG and the connection dropped.
2380/// The loop exits when the owning [`TlsListener`] is dropped.
2381async fn run_tls_acceptor(
2382    listener: TcpListener,
2383    acceptor: tokio_rustls::TlsAcceptor,
2384    default_role: String,
2385    tx: mpsc::Sender<(AuthenticatedTlsStream, SocketAddr)>,
2386    handshake_timeout: Duration,
2387    max_concurrent_handshakes: usize,
2388) {
2389    let inflight = Arc::new(Semaphore::new(max_concurrent_handshakes));
2390    loop {
2391        // Acquire the permit BEFORE accepting: at saturation, pending
2392        // connections wait in the kernel backlog instead of being accepted
2393        // and then buffered or dropped in user space.
2394        let Ok(permit) = Arc::clone(&inflight).acquire_owned().await else {
2395            // The semaphore is never closed; defensive exit.
2396            return;
2397        };
2398        let (stream, addr) = match listener.accept().await {
2399            Ok(pair) => pair,
2400            Err(e) => {
2401                tracing::debug!("TCP accept error: {e}");
2402                continue;
2403            }
2404        };
2405        if tx.is_closed() {
2406            // The listener was dropped (shutdown): stop accepting.
2407            return;
2408        }
2409        let acceptor = acceptor.clone();
2410        let default_role = default_role.clone();
2411        let tx = tx.clone();
2412        tokio::spawn(async move {
2413            let _permit = permit;
2414            match tokio::time::timeout(handshake_timeout, acceptor.accept(stream)).await {
2415                Ok(Ok(tls_stream)) => {
2416                    let identity =
2417                        TlsListener::extract_handshake_identity(&tls_stream, &default_role, addr);
2418                    let wrapped = AuthenticatedTlsStream {
2419                        inner: tls_stream,
2420                        identity,
2421                    };
2422                    // The receiver only disappears during shutdown; discard
2423                    // the completed connection quietly rather than logging.
2424                    let _ = tx.send((wrapped, addr)).await;
2425                }
2426                Ok(Err(e)) => {
2427                    tracing::debug!("TLS handshake failed from {addr}: {e}");
2428                }
2429                Err(_elapsed) => {
2430                    tracing::debug!(
2431                        "TLS handshake timed out from {addr} after {handshake_timeout:?}"
2432                    );
2433                }
2434            }
2435        });
2436    }
2437}
2438
/// A TLS stream paired with the mTLS identity extracted at handshake time.
///
/// Wraps [`tokio_rustls::server::TlsStream`] so the verified client
/// identity travels with the connection itself. This replaces the previous
/// shared `MtlsIdentities` map, eliminating the
/// `(SocketAddr) -> AuthIdentity` aliasing risk caused by ephemeral-port
/// reuse and removing the need for an LRU eviction policy.
///
/// The wrapper is `Unpin` (its inner stream is `Unpin` because
/// [`tokio::net::TcpStream`] is `Unpin`), so `AsyncRead`/`AsyncWrite`
/// delegation uses safe pin projection via `Pin::new(&mut self.inner)`.
pub(crate) struct AuthenticatedTlsStream {
    /// The established server-side TLS stream all I/O is delegated to.
    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
    /// Client identity captured when the handshake completed; `None` when
    /// no identity could be extracted for this connection.
    identity: Option<AuthIdentity>,
}
2454
2455impl AuthenticatedTlsStream {
2456    /// Returns the verified mTLS client identity, if any.
2457    #[must_use]
2458    pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
2459        self.identity.as_ref()
2460    }
2461}
2462
2463impl std::fmt::Debug for AuthenticatedTlsStream {
2464    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2465        f.debug_struct("AuthenticatedTlsStream")
2466            .field("identity", &self.identity.as_ref().map(|id| &id.name))
2467            .finish_non_exhaustive()
2468    }
2469}
2470
2471impl tokio::io::AsyncRead for AuthenticatedTlsStream {
2472    fn poll_read(
2473        mut self: Pin<&mut Self>,
2474        cx: &mut std::task::Context<'_>,
2475        buf: &mut tokio::io::ReadBuf<'_>,
2476    ) -> std::task::Poll<std::io::Result<()>> {
2477        Pin::new(&mut self.inner).poll_read(cx, buf)
2478    }
2479}
2480
2481impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
2482    fn poll_write(
2483        mut self: Pin<&mut Self>,
2484        cx: &mut std::task::Context<'_>,
2485        buf: &[u8],
2486    ) -> std::task::Poll<std::io::Result<usize>> {
2487        Pin::new(&mut self.inner).poll_write(cx, buf)
2488    }
2489
2490    fn poll_flush(
2491        mut self: Pin<&mut Self>,
2492        cx: &mut std::task::Context<'_>,
2493    ) -> std::task::Poll<std::io::Result<()>> {
2494        Pin::new(&mut self.inner).poll_flush(cx)
2495    }
2496
2497    fn poll_shutdown(
2498        mut self: Pin<&mut Self>,
2499        cx: &mut std::task::Context<'_>,
2500    ) -> std::task::Poll<std::io::Result<()>> {
2501        Pin::new(&mut self.inner).poll_shutdown(cx)
2502    }
2503
2504    fn poll_write_vectored(
2505        mut self: Pin<&mut Self>,
2506        cx: &mut std::task::Context<'_>,
2507        bufs: &[std::io::IoSlice<'_>],
2508    ) -> std::task::Poll<std::io::Result<usize>> {
2509        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
2510    }
2511
2512    fn is_write_vectored(&self) -> bool {
2513        self.inner.is_write_vectored()
2514    }
2515}
2516
2517impl axum::serve::Listener for TlsListener {
2518    type Io = AuthenticatedTlsStream;
2519    type Addr = SocketAddr;
2520
2521    /// Yield the next fully-handshaken TLS connection.
2522    ///
2523    /// Cancel-safe: this is a plain `mpsc::Receiver::recv`, so cancelling
2524    /// the future (axum selects it against graceful shutdown) never loses
2525    /// a connection.
2526    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
2527        if let Some(pair) = self.rx.recv().await {
2528            return pair;
2529        }
2530        // The channel only closes if the acceptor task terminated, which
2531        // means the TcpListener is gone and the OS already refuses new
2532        // connections. `Listener::accept` is infallible and panicking is
2533        // forbidden, so park forever: existing connections keep being
2534        // served and graceful shutdown still completes.
2535        tracing::error!("TLS acceptor task terminated; no further connections will be accepted");
2536        std::future::pending().await
2537    }
2538
2539    fn local_addr(&self) -> std::io::Result<Self::Addr> {
2540        Ok(self.local_addr)
2541    }
2542}
2543
/// Stops the background acceptor task when the listener is dropped.
impl Drop for TlsListener {
    fn drop(&mut self) {
        // Stop accepting immediately and release the bound port. In-flight
        // handshake workers notice the closed channel and exit quietly.
        self.acceptor_task.abort();
    }
}
2551
2552fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
2553    use rustls::pki_types::pem::PemObject;
2554    let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
2555        .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
2556        .collect::<Result<_, _>>()
2557        .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
2558    anyhow::ensure!(
2559        !certs.is_empty(),
2560        "no certificates found in {}",
2561        path.display()
2562    );
2563    Ok(certs)
2564}
2565
2566fn load_client_auth_roots(
2567    path: &Path,
2568) -> anyhow::Result<(
2569    Vec<rustls::pki_types::CertificateDer<'static>>,
2570    Arc<RootCertStore>,
2571)> {
2572    let ca_certs = load_certs(path)?;
2573    let mut root_store = RootCertStore::empty();
2574    for cert in &ca_certs {
2575        root_store
2576            .add(cert.clone())
2577            .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
2578    }
2579
2580    Ok((ca_certs, Arc::new(root_store)))
2581}
2582
2583fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
2584    use rustls::pki_types::pem::PemObject;
2585    rustls::pki_types::PrivateKeyDer::from_pem_file(path)
2586        .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
2587}
2588
2589#[allow(
2590    clippy::unused_async,
2591    reason = "axum route handler signature requires `async fn` even when the body is synchronous"
2592)]
2593async fn healthz() -> impl IntoResponse {
2594    axum::Json(serde_json::json!({
2595        "status": "ok",
2596    }))
2597}
2598
2599/// Build the `/version` JSON payload for a given server name and version.
2600///
2601/// Build metadata (`build_git_sha`, `build_timestamp`, `rust_version`) is
2602/// read at compile time from the `RMCP_SERVER_KIT_BUILD_SHA`,
2603/// `RMCP_SERVER_KIT_BUILD_TIME`, and `RMCP_SERVER_KIT_RUSTC_VERSION` env
2604/// vars. Unset values resolve to `"unknown"`.
2605fn version_payload(name: &str, version: &str) -> serde_json::Value {
2606    serde_json::json!({
2607        "name": name,
2608        "version": version,
2609        "build_git_sha": option_env!("RMCP_SERVER_KIT_BUILD_SHA").unwrap_or("unknown"),
2610        "build_timestamp": option_env!("RMCP_SERVER_KIT_BUILD_TIME").unwrap_or("unknown"),
2611        "rust_version": option_env!("RMCP_SERVER_KIT_RUSTC_VERSION").unwrap_or("unknown"),
2612        "mcpx_version": env!("CARGO_PKG_VERSION"),
2613    })
2614}
2615
2616/// Pre-serialize the `/version` payload to immutable bytes.
2617///
2618/// This is called once at router-build time so per-request handling can
2619/// reuse a cheap `Arc<[u8]>` clone instead of re-serializing a
2620/// [`serde_json::Value`] on every hit.
2621///
2622/// Serialization of a flat `serde_json::Value` of static-string fields
2623/// cannot fail in practice; the fallback to `b"{}"` exists only to
2624/// satisfy the crate-wide `unwrap_used` / `expect_used` lint policy.
2625fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
2626    let value = version_payload(name, version);
2627    serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
2628}
2629
2630async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2631    let status = check().await;
2632    let ready = status
2633        .get("ready")
2634        .and_then(serde_json::Value::as_bool)
2635        .unwrap_or(false);
2636    let code = if ready {
2637        axum::http::StatusCode::OK
2638    } else {
2639        axum::http::StatusCode::SERVICE_UNAVAILABLE
2640    };
2641    (code, axum::Json(status))
2642}
2643
2644/// Wait for SIGINT (ctrl-c) or SIGTERM (container stop).
2645///
2646/// On non-Unix platforms, only SIGINT is handled.
2647async fn shutdown_signal() {
2648    let ctrl_c = tokio::signal::ctrl_c();
2649
2650    #[cfg(unix)]
2651    {
2652        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
2653            Ok(mut term) => {
2654                tokio::select! {
2655                    _ = ctrl_c => {}
2656                    _ = term.recv() => {}
2657                }
2658            }
2659            Err(e) => {
2660                tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
2661                ctrl_c.await.ok();
2662            }
2663        }
2664    }
2665
2666    #[cfg(not(unix))]
2667    {
2668        ctrl_c.await.ok();
2669    }
2670}
2671
2672// -- Origin validation (MCP 2025-11-25 spec, section 2.0.1) --
2673
/// Middleware that records HTTP request metrics (method, path, status,
/// duration).
///
/// The request is timed around `next.run`, then one increment is applied
/// to `http_requests_total` (labels: method, path, status) and one
/// observation, in seconds, to `http_request_duration_seconds` (labels:
/// method, path). The response is returned unchanged.
///
/// NOTE(review): the `path` label is the raw request URI path, so distinct
/// request paths produce distinct label values — confirm that an upstream
/// layer bounds the set of paths that reach this middleware.
#[cfg(feature = "metrics")]
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    // Capture the labels before `req` is moved into the inner service.
    let method = req.method().to_string();
    let path = req.uri().path().to_owned();
    let start = std::time::Instant::now();

    let response = next.run(req).await;

    let status = response.status().as_u16().to_string();
    let duration = start.elapsed().as_secs_f64();

    metrics
        .http_requests_total
        .with_label_values(&[&method, &path, &status])
        .inc();
    metrics
        .http_request_duration_seconds
        .with_label_values(&[&method, &path])
        .observe(duration);

    response
}
2703
2704/// OWASP security header hardening applied to every response.
2705///
2706/// Sets: `X-Content-Type-Options`, `X-Frame-Options`, `Cache-Control`,
2707/// `Referrer-Policy`, `Cross-Origin-Opener-Policy`, `Cross-Origin-Resource-Policy`,
2708/// `Cross-Origin-Embedder-Policy`, `Permissions-Policy`,
2709/// `X-Permitted-Cross-Domain-Policies`, `Content-Security-Policy`,
2710/// `X-DNS-Prefetch-Control`, and (when TLS is active) `Strict-Transport-Security`.
2711///
2712/// Each header's value can be customised via [`SecurityHeadersConfig`]
2713/// on [`McpServerConfig`]. See that type for the three-state semantic
2714/// (`None` = default, `Some("")` = omit, `Some(v)` = override).
2715async fn security_headers_middleware(
2716    is_tls: bool,
2717    cfg: Arc<SecurityHeadersConfig>,
2718    req: Request<Body>,
2719    next: Next,
2720) -> axum::response::Response {
2721    use axum::http::{HeaderName, header};
2722
2723    let mut resp = next.run(req).await;
2724    let headers = resp.headers_mut();
2725
2726    // Strip server identity headers to reduce information leakage.
2727    headers.remove(header::SERVER);
2728    headers.remove(HeaderName::from_static("x-powered-by"));
2729
2730    apply_security_header(
2731        headers,
2732        header::X_CONTENT_TYPE_OPTIONS,
2733        cfg.x_content_type_options.as_deref(),
2734        "nosniff",
2735    );
2736    apply_security_header(
2737        headers,
2738        header::X_FRAME_OPTIONS,
2739        cfg.x_frame_options.as_deref(),
2740        "deny",
2741    );
2742    apply_security_header(
2743        headers,
2744        header::CACHE_CONTROL,
2745        cfg.cache_control.as_deref(),
2746        "no-store, max-age=0",
2747    );
2748    apply_security_header(
2749        headers,
2750        header::REFERRER_POLICY,
2751        cfg.referrer_policy.as_deref(),
2752        "no-referrer",
2753    );
2754    apply_security_header(
2755        headers,
2756        HeaderName::from_static("cross-origin-opener-policy"),
2757        cfg.cross_origin_opener_policy.as_deref(),
2758        "same-origin",
2759    );
2760    apply_security_header(
2761        headers,
2762        HeaderName::from_static("cross-origin-resource-policy"),
2763        cfg.cross_origin_resource_policy.as_deref(),
2764        "same-origin",
2765    );
2766    apply_security_header(
2767        headers,
2768        HeaderName::from_static("cross-origin-embedder-policy"),
2769        cfg.cross_origin_embedder_policy.as_deref(),
2770        "require-corp",
2771    );
2772    apply_security_header(
2773        headers,
2774        HeaderName::from_static("permissions-policy"),
2775        cfg.permissions_policy.as_deref(),
2776        "accelerometer=(), camera=(), geolocation=(), microphone=()",
2777    );
2778    apply_security_header(
2779        headers,
2780        HeaderName::from_static("x-permitted-cross-domain-policies"),
2781        cfg.x_permitted_cross_domain_policies.as_deref(),
2782        "none",
2783    );
2784    apply_security_header(
2785        headers,
2786        HeaderName::from_static("content-security-policy"),
2787        cfg.content_security_policy.as_deref(),
2788        "default-src 'none'; frame-ancestors 'none'",
2789    );
2790    apply_security_header(
2791        headers,
2792        HeaderName::from_static("x-dns-prefetch-control"),
2793        cfg.x_dns_prefetch_control.as_deref(),
2794        "off",
2795    );
2796
2797    if is_tls {
2798        apply_security_header(
2799            headers,
2800            header::STRICT_TRANSPORT_SECURITY,
2801            cfg.strict_transport_security.as_deref(),
2802            "max-age=63072000; includeSubDomains",
2803        );
2804    }
2805
2806    resp
2807}
2808
2809/// Set a single security header on the response, honouring the
2810/// three-state override semantic (None = default, Some("") = omit,
2811/// Some(value) = override).
2812///
2813/// Defence-in-depth: if an override value somehow reaches this point
2814/// despite [`validate_security_headers`] having approved it (e.g. a
2815/// runtime mutation on a non-`Validated` field), we log at error level
2816/// and fall back to the static default rather than panicking. The
2817/// `Validated<McpServerConfig>` type makes that path unreachable in
2818/// well-typed code paths.
2819fn apply_security_header(
2820    headers: &mut axum::http::HeaderMap,
2821    name: axum::http::HeaderName,
2822    override_value: Option<&str>,
2823    default: &'static str,
2824) {
2825    use axum::http::HeaderValue;
2826
2827    match override_value {
2828        None => {
2829            headers.insert(name, HeaderValue::from_static(default));
2830        }
2831        Some("") => {
2832            // Operator explicitly opted out of this header.
2833        }
2834        Some(v) => match HeaderValue::from_str(v) {
2835            Ok(hv) => {
2836                headers.insert(name, hv);
2837            }
2838            Err(err) => {
2839                tracing::error!(
2840                    header = %name,
2841                    error = %err,
2842                    "invalid security header override reached middleware; using default"
2843                );
2844                headers.insert(name, HeaderValue::from_static(default));
2845            }
2846        },
2847    }
2848}
2849
2850/// Validate every non-empty entry in a [`SecurityHeadersConfig`].
2851///
2852/// - `None` and `Some("")` are accepted unconditionally (use-default and
2853///   omit, respectively).
2854/// - `Some(v)` is rejected if `axum::http::HeaderValue::from_str(v)` fails.
2855/// - `strict_transport_security` additionally rejects any value
2856///   containing `preload` (case-insensitive). Operators who genuinely
2857///   want to commit to the HSTS preload list must do so via a future
2858///   explicit `with_hsts_preload(true)` builder, not by smuggling
2859///   `preload` through this knob.
2860fn validate_security_headers(cfg: &SecurityHeadersConfig) -> Result<(), McpxError> {
2861    use axum::http::HeaderValue;
2862
2863    let fields: &[(&str, Option<&str>)] = &[
2864        (
2865            "x_content_type_options",
2866            cfg.x_content_type_options.as_deref(),
2867        ),
2868        ("x_frame_options", cfg.x_frame_options.as_deref()),
2869        ("cache_control", cfg.cache_control.as_deref()),
2870        ("referrer_policy", cfg.referrer_policy.as_deref()),
2871        (
2872            "cross_origin_opener_policy",
2873            cfg.cross_origin_opener_policy.as_deref(),
2874        ),
2875        (
2876            "cross_origin_resource_policy",
2877            cfg.cross_origin_resource_policy.as_deref(),
2878        ),
2879        (
2880            "cross_origin_embedder_policy",
2881            cfg.cross_origin_embedder_policy.as_deref(),
2882        ),
2883        ("permissions_policy", cfg.permissions_policy.as_deref()),
2884        (
2885            "x_permitted_cross_domain_policies",
2886            cfg.x_permitted_cross_domain_policies.as_deref(),
2887        ),
2888        (
2889            "content_security_policy",
2890            cfg.content_security_policy.as_deref(),
2891        ),
2892        (
2893            "x_dns_prefetch_control",
2894            cfg.x_dns_prefetch_control.as_deref(),
2895        ),
2896        (
2897            "strict_transport_security",
2898            cfg.strict_transport_security.as_deref(),
2899        ),
2900    ];
2901
2902    for (field, value) in fields {
2903        let Some(v) = value else { continue };
2904        if v.is_empty() {
2905            continue;
2906        }
2907        if let Err(err) = HeaderValue::from_str(v) {
2908            return Err(McpxError::Config(format!(
2909                "invalid security_headers.{field}: {err}"
2910            )));
2911        }
2912    }
2913
2914    if let Some(v) = cfg.strict_transport_security.as_deref()
2915        && !v.is_empty()
2916        && v.to_ascii_lowercase().contains("preload")
2917    {
2918        return Err(McpxError::Config(format!(
2919            "invalid security_headers.strict_transport_security: {v:?} contains the `preload` directive; \
2920             HSTS preload must be opted into explicitly via a dedicated builder, not via this knob"
2921        )));
2922    }
2923
2924    Ok(())
2925}
2926
/// Add the cache and `Vary` headers that RFC 6749 §5.1 / RFC 6750 §5.4
/// require on OAuth token-issuing responses.
///
/// `Cache-Control: no-store, max-age=0` is already applied globally by
/// [`security_headers_middleware`]; this middleware contributes:
///
/// - `Pragma: no-cache` -- mandated by RFC 6749 §5.1 for HTTP/1.0 caches.
/// - `Vary: Authorization` -- mandated by RFC 6750 §5.4 for endpoints
///   whose response depends on the `Authorization` header.
///
/// Applied only to the OAuth proxy token-class endpoints (`/token`,
/// `/register`, `/introspect`, `/revoke`). `Pragma` replaces any existing
/// value, whereas `Vary` is appended so a `Vary` value already present
/// (e.g. `Accept-Encoding` from a compression layer, or `Origin` from a
/// CORS layer) is preserved.
#[cfg(feature = "oauth")]
async fn oauth_token_cache_headers_middleware(
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    use axum::http::{HeaderValue, header};

    let mut response = next.run(req).await;

    let pragma = HeaderValue::from_static("no-cache");
    response.headers_mut().insert(header::PRAGMA, pragma);

    let vary = HeaderValue::from_static("Authorization");
    response.headers_mut().append(header::VARY, vary);

    response
}
2954
2955/// Normalize peer-address request extensions across listener branches.
2956///
2957/// The make-service installs `ConnectInfo<SocketAddr>` on the plain
2958/// listener but `ConnectInfo<TlsConnInfo>` on the TLS listener (the
2959/// latter additionally carries the connection-bound mTLS identity and
2960/// stays `pub(crate)` — see the anti-aliasing rationale on
2961/// [`TlsConnInfo`]). Application routes — in particular those merged via
2962/// [`McpServerConfig::with_extra_router`], which bypass the auth
2963/// middleware and its private fallback — could therefore not read the
2964/// peer address under TLS.
2965///
2966/// This middleware makes both branches look identical to every route and
2967/// inner middleware:
2968///
2969/// 1. mirrors the TLS peer address into `ConnectInfo<SocketAddr>` when
2970///    (and only when) it is absent, so stock axum-ecosystem extractors
2971///    work unmodified, and
2972/// 2. inserts the framework-owned [`PeerAddr`] extension on both
2973///    branches.
2974///
2975/// Precedence mirrors the auth middleware: an existing
2976/// `ConnectInfo<SocketAddr>` always wins and is never overwritten. The
2977/// peer address is deliberately not logged here.
2978async fn normalize_peer_addr_middleware(
2979    mut req: Request<Body>,
2980    next: Next,
2981) -> axum::response::Response {
2982    let direct = req
2983        .extensions()
2984        .get::<ConnectInfo<SocketAddr>>()
2985        .map(|ci| ci.0);
2986    let from_tls = req
2987        .extensions()
2988        .get::<ConnectInfo<TlsConnInfo>>()
2989        .map(|ci| ci.0.addr);
2990    if let Some(addr) = direct.or(from_tls) {
2991        if direct.is_none() {
2992            req.extensions_mut().insert(ConnectInfo(addr));
2993        }
2994        req.extensions_mut().insert(PeerAddr::new(addr));
2995    }
2996    next.run(req).await
2997}
2998
/// Per-IP rate limiter for `extra_router` routes, keyed by the direct
/// socket peer address. Same memory-bounded machinery as the tool
/// limiter ([`crate::rbac`]).
///
/// Instances are created by `build_extra_route_rate_limiter` and
/// consulted by `extra_route_rate_limit_middleware`.
pub(crate) type ExtraRouteRateLimiter = BoundedKeyedLimiter<IpAddr>;

/// Cap on distinct source IPs tracked by the extra-route limiter.
/// Mirrors the tool limiter's bound: memory stays bounded at saturation
/// via idle-prune + LRU eviction, at the cost of shared-fate fairness
/// under key spray (an attacker churning many IPs can reset quieter
/// legitimate IPs to fresh buckets).
///
/// Passed to `BoundedKeyedLimiter::new` by
/// `build_extra_route_rate_limiter`.
const EXTRA_ROUTE_MAX_TRACKED_KEYS: usize = 10_000;

/// Idle-eviction window for the extra-route limiter (15 minutes),
/// mirroring the tool limiter.
///
/// Passed to `BoundedKeyedLimiter::new` by
/// `build_extra_route_rate_limiter`.
const EXTRA_ROUTE_IDLE_EVICTION: Duration = Duration::from_mins(15);
3014
3015/// Build the per-IP limiter for `extra_router` routes.
3016///
3017/// `per_minute` and `burst` are validated nonzero by
3018/// [`McpServerConfig::validate`]; the `NonZeroU32` fallbacks here are
3019/// defensive only. `burst` overrides governor's default bucket capacity
3020/// (burst = rate).
3021fn build_extra_route_rate_limiter(
3022    per_minute: u32,
3023    burst: Option<u32>,
3024) -> Arc<ExtraRouteRateLimiter> {
3025    let rate = std::num::NonZeroU32::new(per_minute.max(1)).unwrap_or(std::num::NonZeroU32::MIN);
3026    let mut quota = governor::Quota::per_minute(rate);
3027    if let Some(b) = burst.and_then(std::num::NonZeroU32::new) {
3028        quota = quota.allow_burst(b);
3029    }
3030    Arc::new(BoundedKeyedLimiter::new(
3031        quota,
3032        EXTRA_ROUTE_MAX_TRACKED_KEYS,
3033        EXTRA_ROUTE_IDLE_EVICTION,
3034    ))
3035}
3036
3037/// Per-IP rate limit middleware for `extra_router` routes.
3038///
3039/// Applied to the application-supplied router **before** it is merged
3040/// into the top-level router, so it wraps exactly the extra routes
3041/// (and their fallback, if any) and nothing else — `/mcp`, health,
3042/// admin, and OAuth endpoints are never affected. Outer layers (origin
3043/// check, peer-address normalization, security headers, metrics) still
3044/// wrap these routes and run first, so both `ConnectInfo` forms are
3045/// populated by the time this middleware reads them.
3046///
3047/// Semantics mirror the tool/auth limiters exactly: keyed by the
3048/// direct peer `IpAddr` (no `X-Forwarded-For`), fail-open when no peer
3049/// address is present (cannot happen under [`serve`]), and on limit a
3050/// plain-text 429 via [`McpxError::RateLimitedFor`] carrying a
3051/// `Retry-After` header (delta-seconds), consistent with every other
3052/// limiter in the crate.
3053async fn extra_route_rate_limit_middleware(
3054    limiter: Arc<ExtraRouteRateLimiter>,
3055    req: Request<Body>,
3056    next: Next,
3057) -> axum::response::Response {
3058    let peer_ip: Option<IpAddr> = req
3059        .extensions()
3060        .get::<ConnectInfo<SocketAddr>>()
3061        .map(|ci| ci.0.ip())
3062        .or_else(|| {
3063            req.extensions()
3064                .get::<ConnectInfo<TlsConnInfo>>()
3065                .map(|ci| ci.0.addr.ip())
3066        });
3067    if let Some(ip) = peer_ip
3068        && let Err(wait) = limiter.check_key_wait(&ip)
3069    {
3070        tracing::warn!(%ip, "extra route request rate limited");
3071        return McpxError::RateLimitedFor {
3072            message: "too many requests to application routes from this source".into(),
3073            retry_after: wait,
3074        }
3075        .into_response();
3076    }
3077    next.run(req).await
3078}
3079
3080/// Per the MCP spec: if the Origin header is present and its value is not in
3081/// the allowed list, respond with 403 Forbidden. Requests without an Origin
3082/// header are allowed through (e.g. non-browser clients like curl, SDKs).
3083async fn origin_check_middleware(
3084    allowed: Arc<[String]>,
3085    log_request_headers: bool,
3086    req: Request<Body>,
3087    next: Next,
3088) -> axum::response::Response {
3089    let method = req.method().clone();
3090    let path = req.uri().path().to_owned();
3091
3092    log_incoming_request(&method, &path, req.headers(), log_request_headers);
3093
3094    if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
3095        let origin_str = origin.to_str().unwrap_or("");
3096        if !allowed.iter().any(|a| a == origin_str) {
3097            tracing::warn!(
3098                origin = origin_str,
3099                %method,
3100                %path,
3101                allowed = ?&*allowed,
3102                "rejected request: Origin not allowed"
3103            );
3104            return (
3105                axum::http::StatusCode::FORBIDDEN,
3106                "Forbidden: Origin not allowed",
3107            )
3108                .into_response();
3109        }
3110    }
3111    next.run(req).await
3112}
3113
3114/// Emit a DEBUG log for an incoming request, optionally including the full
3115/// (redacted) header set.
3116fn log_incoming_request(
3117    method: &axum::http::Method,
3118    path: &str,
3119    headers: &axum::http::HeaderMap,
3120    log_request_headers: bool,
3121) {
3122    if log_request_headers {
3123        tracing::debug!(
3124            %method,
3125            %path,
3126            headers = %format_request_headers_for_log(headers),
3127            "incoming request"
3128        );
3129    } else {
3130        tracing::debug!(%method, %path, "incoming request");
3131    }
3132}
3133
3134fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
3135    headers
3136        .iter()
3137        .map(|(k, v)| {
3138            let name = k.as_str();
3139            if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
3140                format!("{name}: [REDACTED]")
3141            } else {
3142                format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
3143            }
3144        })
3145        .collect::<Vec<_>>()
3146        .join(", ")
3147}
3148
3149// -- stdio transport --
3150
3151/// Serve an MCP server over stdin/stdout (stdio transport).
3152///
3153/// # Security warnings
3154///
3155/// - **No authentication**: the parent process has full, unrestricted access.
3156/// - **No RBAC**: all tools are available regardless of policy.
3157/// - **No TLS**: messages travel over OS pipes in plaintext.
3158/// - **Single client**: only the parent process can connect.
3159/// - **No Origin validation**: not applicable to stdio.
3160///
3161/// Use this only when the MCP client spawns the server as a trusted subprocess
3162/// (e.g. Claude Desktop, VS Code Copilot). For network-accessible deployments,
3163/// use `serve()` (Streamable HTTP) instead.
3164///
3165/// # Errors
3166///
3167/// Returns [`McpxError::Startup`] if the handler fails to initialize or the
3168/// transport disconnects unexpectedly.
3169// NOTE: reported complexity 32/25 is driven entirely by `tracing::*!`
3170// macro expansion in this 18-line function (info/warn/info + two matches).
3171// There is nothing meaningful to extract; the allow stays.
3172#[allow(
3173    clippy::cognitive_complexity,
3174    reason = "complexity is purely tracing macro expansion (info/warn + match arms); 18 lines of straight-line code, nothing meaningful to extract"
3175)]
3176pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
3177where
3178    H: ServerHandler + 'static,
3179{
3180    use rmcp::ServiceExt as _;
3181
3182    tracing::info!("stdio transport: serving on stdin/stdout");
3183    tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
3184
3185    let transport = rmcp::transport::io::stdio();
3186
3187    let service = handler
3188        .serve(transport)
3189        .await
3190        .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
3191
3192    if let Err(e) = service.waiting().await {
3193        tracing::warn!(error = %e, "stdio session ended with error");
3194    }
3195    tracing::info!("stdio session ended");
3196    Ok(())
3197}
3198
3199#[cfg(test)]
3200mod tests {
3201    #![allow(
3202        clippy::unwrap_used,
3203        clippy::expect_used,
3204        clippy::panic,
3205        clippy::indexing_slicing,
3206        clippy::unwrap_in_result,
3207        clippy::print_stdout,
3208        clippy::print_stderr,
3209        deprecated,
3210        reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
3211    )]
3212    use std::{sync::Arc, time::Duration};
3213
3214    use axum::{
3215        body::Body,
3216        http::{Request, StatusCode, header},
3217        response::IntoResponse,
3218    };
3219    use http_body_util::BodyExt;
3220    use tower::ServiceExt as _;
3221
3222    use super::*;
3223
3224    // -- McpServerConfig --
3225
3226    #[test]
3227    fn server_config_new_defaults() {
3228        let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
3229        assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
3230        assert_eq!(cfg.name, "test-server");
3231        assert_eq!(cfg.version, "1.0.0");
3232        assert!(cfg.tls_cert_path.is_none());
3233        assert!(cfg.tls_key_path.is_none());
3234        assert!(cfg.auth.is_none());
3235        assert!(cfg.rbac.is_none());
3236        assert!(cfg.allowed_origins.is_empty());
3237        assert!(cfg.tool_rate_limit.is_none());
3238        assert!(cfg.readiness_check.is_none());
3239        assert_eq!(cfg.max_request_body, 1024 * 1024);
3240        assert_eq!(cfg.request_timeout, Duration::from_mins(2));
3241        assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
3242        assert!(!cfg.log_request_headers);
3243        assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(10));
3244        assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
3245    }
3246
3247    #[test]
3248    fn tls_handshake_builders_set_fields() {
3249        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3250            .with_tls_handshake_timeout(Duration::from_secs(3))
3251            .with_max_concurrent_tls_handshakes(64);
3252        assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(3));
3253        assert_eq!(cfg.max_concurrent_tls_handshakes, 64);
3254    }
3255
3256    #[test]
3257    fn validate_rejects_zero_tls_handshake_timeout() {
3258        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3259            .with_tls_handshake_timeout(Duration::ZERO);
3260        let err = cfg.validate().expect_err("zero handshake timeout");
3261        assert!(err.to_string().contains("tls_handshake_timeout"));
3262    }
3263
3264    #[test]
3265    fn validate_rejects_zero_max_concurrent_tls_handshakes() {
3266        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3267            .with_max_concurrent_tls_handshakes(0);
3268        let err = cfg.validate().expect_err("zero handshake concurrency");
3269        assert!(err.to_string().contains("max_concurrent_tls_handshakes"));
3270    }
3271
3272    #[test]
3273    fn validate_consumes_and_proves() {
3274        // Valid config -> Validated wrapper, original is consumed.
3275        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
3276        let validated = cfg.validate().expect("valid config");
3277        // as_inner() gives read-only access to inner fields.
3278        assert_eq!(validated.as_inner().name, "test-server");
3279        // into_inner recovers the raw value.
3280        let raw = validated.into_inner();
3281        assert_eq!(raw.name, "test-server");
3282
3283        // Invalid config (zero max_request_body) -> Err.
3284        let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
3285        bad.max_request_body = 0;
3286        assert!(bad.validate().is_err(), "zero body cap must fail validate");
3287    }
3288
3289    #[test]
3290    fn validate_rejects_zero_max_concurrent_requests() {
3291        let cfg =
3292            McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_max_concurrent_requests(0);
3293        let err = cfg.validate().expect_err("zero concurrency cap must fail");
3294        assert!(
3295            format!("{err}").contains("max_concurrent_requests"),
3296            "error should mention max_concurrent_requests, got: {err}"
3297        );
3298    }
3299
3300    #[test]
3301    fn validate_rejects_zero_max_tracked_keys() {
3302        // Defaults mirror auth::default_max_attempts / default_idle_eviction
3303        // (module-private in auth.rs); spelled out here for review clarity.
3304        let rl = crate::auth::RateLimitConfig {
3305            max_attempts_per_minute: 30,
3306            pre_auth_max_per_minute: None,
3307            max_tracked_keys: 0,
3308            idle_eviction: Duration::from_secs(15 * 60),
3309            burst: None,
3310            pre_auth_burst: None,
3311        };
3312        let auth_cfg = AuthConfig {
3313            enabled: true,
3314            api_keys: Vec::new(),
3315            mtls: None,
3316            rate_limit: Some(rl),
3317            #[cfg(feature = "oauth")]
3318            oauth: None,
3319        };
3320        let cfg = McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_auth(auth_cfg);
3321        let err = cfg.validate().expect_err("zero max_tracked_keys must fail");
3322        assert!(
3323            format!("{err}").contains("max_tracked_keys"),
3324            "error should mention max_tracked_keys, got: {err}"
3325        );
3326    }
3327
3328    #[test]
3329    fn derive_allowed_hosts_includes_public_host() {
3330        let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
3331        assert!(
3332            hosts.iter().any(|h| h == "mcp.example.com"),
3333            "public_url host must be allowed"
3334        );
3335    }
3336
3337    #[test]
3338    fn derive_allowed_hosts_includes_bind_authority() {
3339        let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
3340        assert!(
3341            hosts.iter().any(|h| h == "127.0.0.1"),
3342            "bind host must be allowed"
3343        );
3344        assert!(
3345            hosts.iter().any(|h| h == "127.0.0.1:8080"),
3346            "bind authority must be allowed"
3347        );
3348    }
3349
3350    // -- healthz --
3351
3352    #[tokio::test]
3353    async fn healthz_returns_ok_json() {
3354        let resp = healthz().await.into_response();
3355        assert_eq!(resp.status(), StatusCode::OK);
3356        let body = resp.into_body().collect().await.unwrap().to_bytes();
3357        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
3358        assert_eq!(json["status"], "ok");
3359        assert!(
3360            json.get("name").is_none(),
3361            "healthz must not expose server name"
3362        );
3363        assert!(
3364            json.get("version").is_none(),
3365            "healthz must not expose version"
3366        );
3367    }
3368
3369    // -- readyz --
3370
3371    #[tokio::test]
3372    async fn readyz_returns_ok_when_ready() {
3373        let check: ReadinessCheck =
3374            Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
3375        let resp = readyz(check).await.into_response();
3376        assert_eq!(resp.status(), StatusCode::OK);
3377        let body = resp.into_body().collect().await.unwrap().to_bytes();
3378        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
3379        assert_eq!(json["ready"], true);
3380        assert!(
3381            json.get("name").is_none(),
3382            "readyz must not expose server name"
3383        );
3384        assert!(
3385            json.get("version").is_none(),
3386            "readyz must not expose version"
3387        );
3388        assert_eq!(json["db"], "connected");
3389    }
3390
3391    #[tokio::test]
3392    async fn readyz_returns_503_when_not_ready() {
3393        let check: ReadinessCheck =
3394            Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
3395        let resp = readyz(check).await.into_response();
3396        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
3397    }
3398
3399    #[tokio::test]
3400    async fn readyz_returns_503_when_ready_missing() {
3401        let check: ReadinessCheck =
3402            Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
3403        let resp = readyz(check).await.into_response();
3404        // Missing "ready" field defaults to false -> 503
3405        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
3406    }
3407
3408    // -- normalize_peer_addr_middleware / PeerAddr --
3409
3410    /// Build a test router that reports the request's peer-address
3411    /// extensions as `"<ConnectInfo>|<PeerAddr>"` (empty when absent).
3412    fn peer_probe_router() -> axum::Router {
3413        async fn probe(req: Request<Body>) -> String {
3414            let ci = req
3415                .extensions()
3416                .get::<ConnectInfo<SocketAddr>>()
3417                .map(|c| c.0.to_string())
3418                .unwrap_or_default();
3419            let pa = req
3420                .extensions()
3421                .get::<PeerAddr>()
3422                .map(|p| p.addr.to_string())
3423                .unwrap_or_default();
3424            format!("{ci}|{pa}")
3425        }
3426        axum::Router::new()
3427            .route("/probe", axum::routing::get(probe))
3428            .layer(axum::middleware::from_fn(normalize_peer_addr_middleware))
3429    }
3430
3431    async fn body_string(resp: axum::response::Response) -> String {
3432        let bytes = resp.into_body().collect().await.unwrap().to_bytes();
3433        String::from_utf8(bytes.to_vec()).unwrap()
3434    }
3435
3436    #[tokio::test]
3437    async fn normalize_preserves_existing_connect_info_and_mirrors_peer_addr() {
3438        // Precedence proof: when both extensions exist with DIFFERENT
3439        // addresses, ConnectInfo<SocketAddr> wins and is never overwritten.
3440        let plain: SocketAddr = "10.0.0.1:1111".parse().unwrap();
3441        let tls: SocketAddr = "10.0.0.2:2222".parse().unwrap();
3442        let req = Request::builder()
3443            .uri("/probe")
3444            .extension(ConnectInfo(plain))
3445            .extension(ConnectInfo(TlsConnInfo::new(tls, None)))
3446            .body(Body::empty())
3447            .unwrap();
3448        let resp = peer_probe_router().oneshot(req).await.unwrap();
3449        assert_eq!(resp.status(), StatusCode::OK);
3450        assert_eq!(body_string(resp).await, format!("{plain}|{plain}"));
3451    }
3452
3453    #[tokio::test]
3454    async fn normalize_inserts_connect_info_and_peer_addr_from_tls() {
3455        let tls: SocketAddr = "192.168.1.7:50443".parse().unwrap();
3456        let req = Request::builder()
3457            .uri("/probe")
3458            .extension(ConnectInfo(TlsConnInfo::new(tls, None)))
3459            .body(Body::empty())
3460            .unwrap();
3461        let resp = peer_probe_router().oneshot(req).await.unwrap();
3462        assert_eq!(resp.status(), StatusCode::OK);
3463        assert_eq!(body_string(resp).await, format!("{tls}|{tls}"));
3464    }
3465
3466    #[tokio::test]
3467    async fn normalize_no_op_without_any_connect_info() {
3468        let req = Request::builder()
3469            .uri("/probe")
3470            .body(Body::empty())
3471            .unwrap();
3472        let resp = peer_probe_router().oneshot(req).await.unwrap();
3473        assert_eq!(resp.status(), StatusCode::OK);
3474        assert_eq!(body_string(resp).await, "|");
3475    }
3476
3477    #[tokio::test]
3478    async fn peer_addr_extractor_rejects_when_absent() {
3479        async fn h(peer: PeerAddr) -> String {
3480            peer.addr.to_string()
3481        }
3482        let app = axum::Router::new().route("/p", axum::routing::get(h));
3483        let req = Request::builder().uri("/p").body(Body::empty()).unwrap();
3484        let resp = app.oneshot(req).await.unwrap();
3485        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
3486    }
3487
3488    #[tokio::test]
3489    async fn peer_addr_extractor_returns_value_when_present() {
3490        async fn h(peer: PeerAddr) -> String {
3491            peer.addr.to_string()
3492        }
3493        let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap();
3494        let app = axum::Router::new().route("/p", axum::routing::get(h));
3495        let req = Request::builder()
3496            .uri("/p")
3497            .extension(PeerAddr::new(addr))
3498            .body(Body::empty())
3499            .unwrap();
3500        let resp = app.oneshot(req).await.unwrap();
3501        assert_eq!(resp.status(), StatusCode::OK);
3502        assert_eq!(body_string(resp).await, addr.to_string());
3503    }
3504
3505    #[tokio::test]
3506    async fn peer_addr_via_extension_extractor() {
3507        async fn h(axum::Extension(peer): axum::Extension<PeerAddr>) -> String {
3508            peer.addr.to_string()
3509        }
3510        let addr: SocketAddr = "127.0.0.1:4242".parse().unwrap();
3511        let app = axum::Router::new().route("/p", axum::routing::get(h));
3512        let req = Request::builder()
3513            .uri("/p")
3514            .extension(PeerAddr::new(addr))
3515            .body(Body::empty())
3516            .unwrap();
3517        let resp = app.oneshot(req).await.unwrap();
3518        assert_eq!(resp.status(), StatusCode::OK);
3519        assert_eq!(body_string(resp).await, addr.to_string());
3520    }
3521
3522    // -- extra_route_rate_limit_middleware --
3523
3524    /// Probe router with the extra-route limiter installed, mirroring
3525    /// the layer-before-merge wiring in `build_app_router`.
3526    fn limited_router(per_minute: u32) -> axum::Router {
3527        limited_router_with_burst(per_minute, None)
3528    }
3529
3530    /// Probe router with an explicit burst capacity.
3531    fn limited_router_with_burst(per_minute: u32, burst: Option<u32>) -> axum::Router {
3532        let limiter = build_extra_route_rate_limiter(per_minute, burst);
3533        axum::Router::new()
3534            .route("/limited", axum::routing::get(|| async { "ok" }))
3535            .layer(axum::middleware::from_fn(move |req, next| {
3536                let l = Arc::clone(&limiter);
3537                extra_route_rate_limit_middleware(l, req, next)
3538            }))
3539    }
3540
3541    fn limited_req(ip: &str) -> Request<Body> {
3542        let addr: SocketAddr = format!("{ip}:40000").parse().unwrap();
3543        Request::builder()
3544            .uri("/limited")
3545            .extension(ConnectInfo(addr))
3546            .body(Body::empty())
3547            .unwrap()
3548    }
3549
3550    #[tokio::test]
3551    async fn extra_route_limiter_denies_over_quota() {
3552        let app = limited_router(2);
3553        for i in 0..2 {
3554            let resp = app.clone().oneshot(limited_req("10.1.1.1")).await.unwrap();
3555            assert_eq!(resp.status(), StatusCode::OK, "request {i} should pass");
3556        }
3557        let resp = app.clone().oneshot(limited_req("10.1.1.1")).await.unwrap();
3558        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
3559        let body = body_string(resp).await;
3560        assert!(
3561            body.contains("too many requests to application routes"),
3562            "deny body should match the limiter message, got: {body}"
3563        );
3564    }
3565
3566    #[tokio::test]
3567    async fn extra_route_limiter_isolates_keys() {
3568        let app = limited_router(2);
3569        for _ in 0..2 {
3570            let resp = app.clone().oneshot(limited_req("10.2.2.2")).await.unwrap();
3571            assert_eq!(resp.status(), StatusCode::OK);
3572        }
3573        let exhausted = app.clone().oneshot(limited_req("10.2.2.2")).await.unwrap();
3574        assert_eq!(exhausted.status(), StatusCode::TOO_MANY_REQUESTS);
3575        // A different source IP still has a fresh bucket.
3576        let other = app.clone().oneshot(limited_req("10.3.3.3")).await.unwrap();
3577        assert_eq!(other.status(), StatusCode::OK);
3578    }
3579
3580    #[tokio::test]
3581    async fn extra_route_limiter_fails_open_without_peer() {
3582        let app = limited_router(1);
3583        for i in 0..3 {
3584            let req = Request::builder()
3585                .uri("/limited")
3586                .body(Body::empty())
3587                .unwrap();
3588            let resp = app.clone().oneshot(req).await.unwrap();
3589            assert_eq!(
3590                resp.status(),
3591                StatusCode::OK,
3592                "request {i} should fail open"
3593            );
3594        }
3595    }
3596
3597    #[tokio::test]
3598    async fn extra_route_limiter_extracts_tls_conn_info() {
3599        let app = limited_router(2);
3600        let mk = || {
3601            let addr: SocketAddr = "192.168.9.9:55555".parse().unwrap();
3602            Request::builder()
3603                .uri("/limited")
3604                .extension(ConnectInfo(TlsConnInfo::new(addr, None)))
3605                .body(Body::empty())
3606                .unwrap()
3607        };
3608        for _ in 0..2 {
3609            assert_eq!(
3610                app.clone().oneshot(mk()).await.unwrap().status(),
3611                StatusCode::OK
3612            );
3613        }
3614        let resp = app.clone().oneshot(mk()).await.unwrap();
3615        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
3616    }
3617
3618    #[test]
3619    fn validate_rejects_zero_extra_route_rate_limit() {
3620        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3621            .with_extra_route_rate_limit(0);
3622        let err = cfg.validate().expect_err("zero extra route rate limit");
3623        assert!(err.to_string().contains("extra_route_rate_limit"));
3624    }
3625
3626    #[tokio::test]
3627    async fn extra_route_limiter_burst_allows_initial_spike() {
3628        let app = limited_router_with_burst(1, Some(3));
3629        for i in 0..3 {
3630            let resp = app.clone().oneshot(limited_req("10.4.4.4")).await.unwrap();
3631            assert_eq!(resp.status(), StatusCode::OK, "burst request {i}");
3632        }
3633        let resp = app.clone().oneshot(limited_req("10.4.4.4")).await.unwrap();
3634        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
3635    }
3636
3637    #[tokio::test]
3638    async fn extra_route_limiter_deny_sets_retry_after() {
3639        let app = limited_router(1);
3640        let ok = app.clone().oneshot(limited_req("10.5.5.5")).await.unwrap();
3641        assert_eq!(ok.status(), StatusCode::OK);
3642        let denied = app.clone().oneshot(limited_req("10.5.5.5")).await.unwrap();
3643        assert_eq!(denied.status(), StatusCode::TOO_MANY_REQUESTS);
3644        let retry_after = denied
3645            .headers()
3646            .get(header::RETRY_AFTER)
3647            .expect("Retry-After present")
3648            .to_str()
3649            .unwrap()
3650            .parse::<u64>()
3651            .unwrap();
3652        assert!(retry_after >= 1, "delta-seconds must be >= 1");
3653    }
3654
3655    #[test]
3656    fn validate_rejects_zero_burst_knobs() {
3657        let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
3658            .with_tool_rate_limit(10)
3659            .with_tool_rate_limit_burst(0)
3660            .validate()
3661            .expect_err("zero tool burst");
3662        assert!(err.to_string().contains("tool_rate_limit_burst"));
3663
3664        let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
3665            .with_extra_route_rate_limit(10)
3666            .with_extra_route_rate_limit_burst(0)
3667            .validate()
3668            .expect_err("zero extra route burst");
3669        assert!(err.to_string().contains("extra_route_rate_limit_burst"));
3670    }
3671
3672    #[test]
3673    fn validate_rejects_orphan_burst_knobs() {
3674        let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
3675            .with_tool_rate_limit_burst(5)
3676            .validate()
3677            .expect_err("orphan tool burst");
3678        assert!(err.to_string().contains("requires tool_rate_limit"));
3679
3680        let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
3681            .with_extra_route_rate_limit_burst(5)
3682            .validate()
3683            .expect_err("orphan extra route burst");
3684        assert!(err.to_string().contains("requires extra_route_rate_limit"));
3685    }
3686
3687    #[test]
3688    fn validate_rejects_zero_auth_bursts() {
3689        let auth = AuthConfig::with_keys(vec![])
3690            .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_burst(0));
3691        let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
3692            .with_auth(auth)
3693            .validate()
3694            .expect_err("zero auth burst");
3695        assert!(err.to_string().contains("rate_limit.burst"));
3696
3697        let auth = AuthConfig::with_keys(vec![])
3698            .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_pre_auth_burst(0));
3699        let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
3700            .with_auth(auth)
3701            .validate()
3702            .expect_err("zero pre-auth burst");
3703        assert!(err.to_string().contains("pre_auth_burst"));
3704    }
3705
3706    /// `pre_auth_burst` without `pre_auth_max_per_minute` is LEGAL: the
3707    /// pre-auth base rate always resolves (max_attempts_per_minute x 10).
3708    #[test]
3709    fn validate_accepts_pre_auth_burst_without_explicit_pre_auth_rate() {
3710        let auth = AuthConfig::with_keys(vec![])
3711            .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_pre_auth_burst(50));
3712        let cfg = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0").with_auth(auth);
3713        assert!(cfg.validate().is_ok(), "pre_auth_burst has no orphan rule");
3714    }
3715
3716    // -- origin_check_middleware --
3717
3718    /// Build a test router with origin check middleware and a simple handler.
3719    fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
3720        let allowed: Arc<[String]> = Arc::from(origins);
3721        axum::Router::new()
3722            .route("/test", axum::routing::get(|| async { "ok" }))
3723            .layer(axum::middleware::from_fn(move |req, next| {
3724                let a = Arc::clone(&allowed);
3725                origin_check_middleware(a, log_request_headers, req, next)
3726            }))
3727    }
3728
3729    #[tokio::test]
3730    async fn origin_allowed_passes() {
3731        let app = origin_router(vec!["http://localhost:3000".into()], false);
3732        let req = Request::builder()
3733            .uri("/test")
3734            .header(header::ORIGIN, "http://localhost:3000")
3735            .body(Body::empty())
3736            .unwrap();
3737        let resp = app.oneshot(req).await.unwrap();
3738        assert_eq!(resp.status(), StatusCode::OK);
3739    }
3740
3741    #[tokio::test]
3742    async fn origin_rejected_returns_403() {
3743        let app = origin_router(vec!["http://localhost:3000".into()], false);
3744        let req = Request::builder()
3745            .uri("/test")
3746            .header(header::ORIGIN, "http://evil.com")
3747            .body(Body::empty())
3748            .unwrap();
3749        let resp = app.oneshot(req).await.unwrap();
3750        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
3751    }
3752
3753    #[tokio::test]
3754    async fn no_origin_header_passes() {
3755        let app = origin_router(vec!["http://localhost:3000".into()], false);
3756        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3757        let resp = app.oneshot(req).await.unwrap();
3758        assert_eq!(resp.status(), StatusCode::OK);
3759    }
3760
3761    #[tokio::test]
3762    async fn empty_allowlist_rejects_any_origin() {
3763        let app = origin_router(vec![], false);
3764        let req = Request::builder()
3765            .uri("/test")
3766            .header(header::ORIGIN, "http://anything.com")
3767            .body(Body::empty())
3768            .unwrap();
3769        let resp = app.oneshot(req).await.unwrap();
3770        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
3771    }
3772
3773    #[tokio::test]
3774    async fn empty_allowlist_passes_without_origin() {
3775        let app = origin_router(vec![], false);
3776        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3777        let resp = app.oneshot(req).await.unwrap();
3778        assert_eq!(resp.status(), StatusCode::OK);
3779    }
3780
3781    #[test]
3782    fn format_request_headers_redacts_sensitive_values() {
3783        let mut headers = axum::http::HeaderMap::new();
3784        headers.insert("authorization", "Bearer secret-token".parse().unwrap());
3785        headers.insert("cookie", "sid=abc".parse().unwrap());
3786        headers.insert("x-request-id", "req-123".parse().unwrap());
3787
3788        let out = format_request_headers_for_log(&headers);
3789        assert!(out.contains("authorization: [REDACTED]"));
3790        assert!(out.contains("cookie: [REDACTED]"));
3791        assert!(out.contains("x-request-id: req-123"));
3792        assert!(!out.contains("secret-token"));
3793    }
3794
3795    // -- security_headers_middleware --
3796
3797    fn security_router(is_tls: bool) -> axum::Router {
3798        security_router_with(is_tls, SecurityHeadersConfig::default())
3799    }
3800
3801    fn security_router_with(is_tls: bool, cfg: SecurityHeadersConfig) -> axum::Router {
3802        let cfg = Arc::new(cfg);
3803        axum::Router::new()
3804            .route("/test", axum::routing::get(|| async { "ok" }))
3805            .layer(axum::middleware::from_fn(move |req, next| {
3806                let c = Arc::clone(&cfg);
3807                security_headers_middleware(is_tls, c, req, next)
3808            }))
3809    }
3810
3811    #[tokio::test]
3812    async fn security_headers_set_on_response() {
3813        let app = security_router(false);
3814        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3815        let resp = app.oneshot(req).await.unwrap();
3816        assert_eq!(resp.status(), StatusCode::OK);
3817
3818        let h = resp.headers();
3819        assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
3820        assert_eq!(h.get("x-frame-options").unwrap(), "deny");
3821        assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
3822        assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
3823        assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
3824        assert_eq!(
3825            h.get("cross-origin-resource-policy").unwrap(),
3826            "same-origin"
3827        );
3828        assert_eq!(
3829            h.get("cross-origin-embedder-policy").unwrap(),
3830            "require-corp"
3831        );
3832        assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
3833        assert!(
3834            h.get("permissions-policy")
3835                .unwrap()
3836                .to_str()
3837                .unwrap()
3838                .contains("camera=()"),
3839            "permissions-policy must restrict browser features"
3840        );
3841        assert_eq!(
3842            h.get("content-security-policy").unwrap(),
3843            "default-src 'none'; frame-ancestors 'none'"
3844        );
3845        assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
3846        // No HSTS when TLS is off.
3847        assert!(h.get("strict-transport-security").is_none());
3848    }
3849
3850    #[tokio::test]
3851    async fn hsts_set_when_tls_enabled() {
3852        let app = security_router(true);
3853        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3854        let resp = app.oneshot(req).await.unwrap();
3855
3856        let hsts = resp.headers().get("strict-transport-security").unwrap();
3857        assert!(
3858            hsts.to_str().unwrap().contains("max-age=63072000"),
3859            "HSTS must set 2-year max-age"
3860        );
3861    }
3862
3863    // -- SecurityHeadersConfig validation + override semantics --
3864
3865    /// Build a minimal config with a custom SecurityHeadersConfig and
3866    /// drive it through `check()`. Returns the result so individual
3867    /// tests can assert on success or specific error messages.
3868    fn check_with_security_headers(headers: SecurityHeadersConfig) -> Result<(), McpxError> {
3869        let cfg =
3870            McpServerConfig::new("127.0.0.1:8080", "test", "0.0.0").with_security_headers(headers);
3871        cfg.check()
3872    }
3873
3874    #[test]
3875    fn security_headers_config_default_validates() {
3876        check_with_security_headers(SecurityHeadersConfig::default())
3877            .expect("default SecurityHeadersConfig must validate");
3878    }
3879
3880    #[test]
3881    fn security_headers_config_validate_accepts_empty_string() {
3882        // All twelve fields explicitly set to "" -> omit-everything mode.
3883        let h = SecurityHeadersConfig {
3884            x_content_type_options: Some(String::new()),
3885            x_frame_options: Some(String::new()),
3886            cache_control: Some(String::new()),
3887            referrer_policy: Some(String::new()),
3888            cross_origin_opener_policy: Some(String::new()),
3889            cross_origin_resource_policy: Some(String::new()),
3890            cross_origin_embedder_policy: Some(String::new()),
3891            permissions_policy: Some(String::new()),
3892            x_permitted_cross_domain_policies: Some(String::new()),
3893            content_security_policy: Some(String::new()),
3894            x_dns_prefetch_control: Some(String::new()),
3895            strict_transport_security: Some(String::new()),
3896        };
3897        check_with_security_headers(h).expect("Some(\"\") on every field must validate (omit-all)");
3898    }
3899
3900    #[test]
3901    fn security_headers_config_validate_rejects_bad_value() {
3902        // 0x07 (BEL) is not a valid HTTP header value char.
3903        let h = SecurityHeadersConfig {
3904            referrer_policy: Some("\u{0007}".into()),
3905            ..SecurityHeadersConfig::default()
3906        };
3907        let err = check_with_security_headers(h)
3908            .expect_err("control char in referrer_policy must reject");
3909        let msg = err.to_string();
3910        assert!(
3911            msg.contains("referrer_policy"),
3912            "error must name the offending field, got: {msg}"
3913        );
3914    }
3915
3916    #[test]
3917    fn security_headers_config_validate_rejects_hsts_preload() {
3918        let h = SecurityHeadersConfig {
3919            strict_transport_security: Some("max-age=63072000; includeSubDomains; preload".into()),
3920            ..SecurityHeadersConfig::default()
3921        };
3922        let err = check_with_security_headers(h).expect_err("HSTS with preload must reject");
3923        let msg = err.to_string();
3924        assert!(
3925            msg.contains("strict_transport_security"),
3926            "error must name the field, got: {msg}"
3927        );
3928        assert!(
3929            msg.to_lowercase().contains("preload"),
3930            "error must mention `preload`, got: {msg}"
3931        );
3932    }
3933
3934    #[test]
3935    fn security_headers_config_validate_rejects_hsts_preload_uppercase() {
3936        // Case-insensitive match.
3937        let h = SecurityHeadersConfig {
3938            strict_transport_security: Some("max-age=600; PRELOAD".into()),
3939            ..SecurityHeadersConfig::default()
3940        };
3941        check_with_security_headers(h).expect_err("HSTS preload check must be case-insensitive");
3942    }
3943
3944    #[tokio::test]
3945    async fn security_headers_override_honored() {
3946        // Override X-Frame-Options to SAMEORIGIN.
3947        let h = SecurityHeadersConfig {
3948            x_frame_options: Some("SAMEORIGIN".into()),
3949            ..SecurityHeadersConfig::default()
3950        };
3951        let app = security_router_with(false, h);
3952        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3953        let resp = app.oneshot(req).await.unwrap();
3954        assert_eq!(resp.status(), StatusCode::OK);
3955
3956        let xfo = resp.headers().get("x-frame-options").unwrap();
3957        assert_eq!(xfo, "SAMEORIGIN");
3958    }
3959
3960    #[tokio::test]
3961    async fn security_headers_empty_string_omits() {
3962        // Empty string on referrer-policy -> header absent.
3963        let h = SecurityHeadersConfig {
3964            referrer_policy: Some(String::new()),
3965            ..SecurityHeadersConfig::default()
3966        };
3967        let app = security_router_with(false, h);
3968        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3969        let resp = app.oneshot(req).await.unwrap();
3970        assert_eq!(resp.status(), StatusCode::OK);
3971
3972        assert!(
3973            resp.headers().get("referrer-policy").is_none(),
3974            "Some(\"\") must omit the header"
3975        );
3976        // Other defaults should still be present.
3977        assert_eq!(
3978            resp.headers().get("x-content-type-options").unwrap(),
3979            "nosniff"
3980        );
3981    }
3982
3983    #[tokio::test]
3984    async fn security_headers_hsts_only_when_tls() {
3985        // HSTS override is irrelevant when TLS is off.
3986        let h = SecurityHeadersConfig {
3987            strict_transport_security: Some("max-age=600".into()),
3988            ..SecurityHeadersConfig::default()
3989        };
3990        let app = security_router_with(false, h);
3991        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3992        let resp = app.oneshot(req).await.unwrap();
3993        assert!(
3994            resp.headers().get("strict-transport-security").is_none(),
3995            "HSTS must remain absent on plaintext deployments even with override"
3996        );
3997    }
3998
3999    // -- oauth_token_cache_headers_middleware --
4000
4001    #[cfg(feature = "oauth")]
4002    #[tokio::test]
4003    async fn oauth_token_cache_headers_set_pragma_and_vary() {
4004        let app = axum::Router::new()
4005            .route("/token", axum::routing::post(|| async { "{}" }))
4006            .layer(axum::middleware::from_fn(
4007                oauth_token_cache_headers_middleware,
4008            ));
4009        let req = Request::builder()
4010            .method("POST")
4011            .uri("/token")
4012            .body(Body::from("{}"))
4013            .unwrap();
4014        let resp = app.oneshot(req).await.unwrap();
4015        assert_eq!(resp.status(), StatusCode::OK);
4016
4017        let h = resp.headers();
4018        assert_eq!(
4019            h.get("pragma").unwrap(),
4020            "no-cache",
4021            "RFC 6749 §5.1: token responses must set Pragma: no-cache"
4022        );
4023        let vary_values: Vec<String> = h
4024            .get_all("vary")
4025            .iter()
4026            .filter_map(|v| v.to_str().ok().map(str::to_owned))
4027            .collect();
4028        assert!(
4029            vary_values
4030                .iter()
4031                .any(|v| v.eq_ignore_ascii_case("Authorization")),
4032            "RFC 6750 §5.4: Vary must include Authorization, got {vary_values:?}"
4033        );
4034    }
4035
4036    #[cfg(feature = "oauth")]
4037    #[tokio::test]
4038    async fn oauth_token_cache_headers_preserve_existing_vary() {
4039        // Simulates a handler/layer that already set `Vary: Accept-Encoding`
4040        // (e.g. compression). Our middleware must APPEND, not REPLACE.
4041        let app = axum::Router::new()
4042            .route(
4043                "/token",
4044                axum::routing::post(|| async {
4045                    axum::response::Response::builder()
4046                        .header("vary", "Accept-Encoding")
4047                        .body(axum::body::Body::from("{}"))
4048                        .unwrap()
4049                }),
4050            )
4051            .layer(axum::middleware::from_fn(
4052                oauth_token_cache_headers_middleware,
4053            ));
4054        let req = Request::builder()
4055            .method("POST")
4056            .uri("/token")
4057            .body(Body::empty())
4058            .unwrap();
4059        let resp = app.oneshot(req).await.unwrap();
4060
4061        let vary: Vec<String> = resp
4062            .headers()
4063            .get_all("vary")
4064            .iter()
4065            .filter_map(|v| v.to_str().ok().map(str::to_owned))
4066            .collect();
4067        assert!(
4068            vary.iter().any(|v| v.contains("Accept-Encoding")),
4069            "must preserve pre-existing Vary value, got {vary:?}"
4070        );
4071        assert!(
4072            vary.iter().any(|v| v.contains("Authorization")),
4073            "must append Authorization to Vary, got {vary:?}"
4074        );
4075    }
4076
4077    // -- version endpoint --
4078
4079    #[test]
4080    fn version_payload_contains_expected_fields() {
4081        let v = version_payload("my-server", "1.2.3");
4082        assert_eq!(v["name"], "my-server");
4083        assert_eq!(v["version"], "1.2.3");
4084        assert!(v["build_git_sha"].is_string());
4085        assert!(v["build_timestamp"].is_string());
4086        assert!(v["rust_version"].is_string());
4087        assert!(v["mcpx_version"].is_string());
4088    }
4089
4090    // -- concurrency limit layer --
4091
4092    #[tokio::test]
4093    async fn concurrency_limit_layer_composes_and_serves() {
4094        // We only assert the layer stack compiles and a single request
4095        // below the cap still succeeds. True back-pressure behaviour
4096        // requires a live HTTP server and is covered by integration tests.
4097        let app = axum::Router::new()
4098            .route("/ok", axum::routing::get(|| async { "ok" }))
4099            .layer(
4100                tower::ServiceBuilder::new()
4101                    .layer(axum::error_handling::HandleErrorLayer::new(
4102                        |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
4103                    ))
4104                    .layer(tower::load_shed::LoadShedLayer::new())
4105                    .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
4106            );
4107        let resp = app
4108            .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
4109            .await
4110            .unwrap();
4111        assert_eq!(resp.status(), StatusCode::OK);
4112    }
4113
4114    // -- compression layer --
4115
4116    #[tokio::test]
4117    async fn compression_layer_gzip_encodes_response() {
4118        use tower_http::compression::Predicate as _;
4119
4120        let big_body = "a".repeat(4096);
4121        let app = axum::Router::new()
4122            .route(
4123                "/big",
4124                axum::routing::get(move || {
4125                    let body = big_body.clone();
4126                    async move { body }
4127                }),
4128            )
4129            .layer(
4130                tower_http::compression::CompressionLayer::new()
4131                    .gzip(true)
4132                    .br(true)
4133                    .compress_when(
4134                        tower_http::compression::DefaultPredicate::new()
4135                            .and(tower_http::compression::predicate::SizeAbove::new(1024)),
4136                    ),
4137            );
4138
4139        let req = Request::builder()
4140            .uri("/big")
4141            .header(header::ACCEPT_ENCODING, "gzip")
4142            .body(Body::empty())
4143            .unwrap();
4144        let resp = app.oneshot(req).await.unwrap();
4145        assert_eq!(resp.status(), StatusCode::OK);
4146        assert_eq!(
4147            resp.headers().get(header::CONTENT_ENCODING).unwrap(),
4148            "gzip"
4149        );
4150    }
4151
4152    // -- TlsListener handshake timeout --
4153
4154    #[tokio::test]
4155    async fn tls_handshake_timeout_reaps_idle_connections() {
4156        use tokio::io::AsyncReadExt as _;
4157
4158        let _ = rustls::crypto::ring::default_provider().install_default();
4159
4160        // Self-signed cert material on disk (TlsListener::new takes paths).
4161        let key = rcgen::KeyPair::generate().expect("generate key");
4162        let cert = rcgen::CertificateParams::new(vec!["localhost".to_owned()])
4163            .expect("cert params")
4164            .self_signed(&key)
4165            .expect("self-signed cert");
4166        let dir = std::env::temp_dir().join(format!(
4167            "rmcp-server-kit-hs-timeout-{}",
4168            std::time::SystemTime::now()
4169                .duration_since(std::time::UNIX_EPOCH)
4170                .expect("clock after epoch")
4171                .as_nanos()
4172        ));
4173        tokio::fs::create_dir_all(&dir).await.expect("temp dir");
4174        let cert_path = dir.join("server.crt");
4175        let key_path = dir.join("server.key");
4176        tokio::fs::write(&cert_path, cert.pem())
4177            .await
4178            .expect("write cert");
4179        tokio::fs::write(&key_path, key.serialize_pem())
4180            .await
4181            .expect("write key");
4182
4183        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
4184        let tls = TlsListener::new(
4185            listener,
4186            &cert_path,
4187            &key_path,
4188            None,
4189            None,
4190            Duration::from_millis(200),
4191            8, // custom concurrency cap: proves the plumbing end-to-end
4192        )
4193        .expect("tls listener");
4194        let addr = axum::serve::Listener::local_addr(&tls).expect("local addr");
4195
4196        // Connect and send NOTHING: the handshake worker must time out
4197        // after 200ms and drop the stream, which the client observes as
4198        // EOF or a reset well within the 2s deadline.
4199        let mut idle = tokio::net::TcpStream::connect(addr).await.expect("connect");
4200        let mut buf = [0_u8; 16];
4201        let read = tokio::time::timeout(Duration::from_secs(2), idle.read(&mut buf))
4202            .await
4203            .expect("server must reap the idle handshake within its timeout");
4204        match read {
4205            Ok(0) | Err(_) => {} // EOF or reset: connection was dropped.
4206            Ok(n) => panic!("unexpected {n} bytes from server during reaped handshake"),
4207        }
4208
4209        drop(tls);
4210    }
4211}