Skip to main content

rmcp_server_kit/
transport.rs

1use std::{
2    future::Future,
3    net::{IpAddr, SocketAddr},
4    path::{Path, PathBuf},
5    pin::Pin,
6    sync::Arc,
7    time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{
12    body::Body,
13    extract::{ConnectInfo, Request},
14    middleware::Next,
15    response::IntoResponse,
16};
17use rmcp::{
18    ServerHandler,
19    transport::streamable_http_server::{
20        StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
21    },
22};
23use rustls::RootCertStore;
24use tokio::{
25    net::TcpListener,
26    sync::{Semaphore, mpsc},
27};
28use tokio_util::sync::CancellationToken;
29
30use crate::{
31    auth::{
32        AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
33        build_rate_limiter, extract_mtls_identity,
34    },
35    bounded_limiter::BoundedKeyedLimiter,
36    error::McpxError,
37    mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
38    rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
39};
40
41/// Map an internal `anyhow::Error` chain into a public [`McpxError::Startup`]
42/// at the public API boundary, flattening the chain via the alternate
43/// formatter so callers see the full causal path.
44#[allow(
45    clippy::needless_pass_by_value,
46    reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
47)]
48fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
49    McpxError::Startup(format!("{e:#}"))
50}
51
52/// Map a `std::io::Error` produced during server startup into a public
53/// [`McpxError::Startup`]. We deliberately do not use the [`McpxError::Io`]
54/// `From` impl here because startup-phase IO errors (bind, listener) are
55/// semantically distinct from request-time IO errors and should surface
56/// the originating operation in the message.
57#[allow(
58    clippy::needless_pass_by_value,
59    reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
60)]
61fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
62    McpxError::Startup(format!("{op}: {e}"))
63}
64
65/// Async readiness check callback for the `/readyz` endpoint.
66///
67/// Returns a JSON object with at least a `"ready"` boolean.
68/// When `ready` is false, the endpoint returns HTTP 503.
69pub type ReadinessCheck =
70    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
71
/// Direct socket peer address of the current HTTP/TLS connection.
///
/// Inserted as a request extension into every request served by [`serve`],
/// on both the plain and the TLS listener. It can be extracted in any axum
/// handler, including routes mounted via
/// [`McpServerConfig::with_extra_router`]. Those routes bypass auth/RBAC
/// and therefore often need the peer address for their own protection,
/// e.g. per-IP rate limiting.
///
/// The same address is also mirrored into
/// [`axum::extract::ConnectInfo`]`<SocketAddr>` on the TLS listener, so
/// third-party middleware that expects the stock axum extension (e.g.
/// per-IP rate-limit key extractors) works unmodified under TLS.
///
/// # Semantics
///
/// - **Direct peer only.** This is the socket's remote address. Behind an
///   L4/L7 proxy or load balancer it is the proxy's address. `PeerAddr`
///   never reflects `X-Forwarded-For` / `Forwarded` headers, even when
///   trusted-forwarder mode ([`McpServerConfig::with_trusted_proxies`]) is
///   active. Use [`ClientIp`] for the (possibly proxy-resolved) client IP.
/// - **Available on HTTP and TLS** transports alike ([`serve`]).
/// - **Absent under [`serve_stdio`]**: a stdio session has no network
///   peer, because stdio bypasses the HTTP stack entirely.
/// - The separate Prometheus metrics listener (feature `metrics`) is a
///   different router and does not carry this extension.
///
/// # Privacy
///
/// `PeerAddr` exposes raw peer network metadata. The framework deliberately
/// never logs it on its own; whether to log or persist peer addresses is
/// application policy.
///
/// # Example
///
/// ```no_run
/// use axum::{Router, routing::get};
/// use rmcp_server_kit::transport::{McpServerConfig, PeerAddr};
///
/// async fn whoami(peer: PeerAddr) -> String {
///     peer.addr.ip().to_string()
/// }
///
/// let _config = McpServerConfig::new("127.0.0.1:8443", "my-server", "1.0.0")
///     .with_extra_router(Router::new().route("/whoami", get(whoami)));
/// ```
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct PeerAddr {
    /// Direct socket peer of this connection.
    pub addr: SocketAddr,
}

impl PeerAddr {
    /// Construct a new [`PeerAddr`] wrapping the socket's remote address.
    ///
    /// Framework-internal: downstream code receives `PeerAddr` via request
    /// extensions and never constructs it.
    #[must_use]
    pub(crate) const fn new(addr: SocketAddr) -> Self {
        Self { addr }
    }
}
131
132/// Extract [`PeerAddr`] from request extensions.
133///
134/// # Rejection
135///
136/// Responds `500 Internal Server Error` when the extension is missing.
137/// A missing `PeerAddr` means the handler is not running under [`serve`]
138/// (e.g. the router was mounted on a hand-rolled listener) — a wiring
139/// bug, not a client error.
140impl<S: Send + Sync> axum::extract::FromRequestParts<S> for PeerAddr {
141    type Rejection = (axum::http::StatusCode, &'static str);
142
143    async fn from_request_parts(
144        parts: &mut axum::http::request::Parts,
145        _state: &S,
146    ) -> Result<Self, Self::Rejection> {
147        parts.extensions.get::<Self>().copied().ok_or((
148            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
149            "peer address unavailable: not running under rmcp-server-kit serve()",
150        ))
151    }
152}
153
/// Client IP resolved for the current request.
///
/// [`serve`] attaches it as a request extension on every request,
/// immediately after [`PeerAddr`]. It is the direct peer's IP, except in
/// one case: **trusted-forwarder mode**
/// ([`McpServerConfig::with_trusted_proxies`]) is on, and the request came
/// through a trusted proxy with a verifiable forwarding chain. In that case
/// it is the rightmost untrusted address taken from `X-Forwarded-For`, or
/// from the RFC 7239 `Forwarded` header when selected via
/// [`McpServerConfig::with_forwarded_header`].
///
/// Every built-in per-IP rate limiter uses this value as its key.
/// [`PeerAddr`] still always holds the direct socket peer. Applications
/// that need provenance can compare `ClientIp.ip` against
/// `PeerAddr.addr.ip()`.
///
/// # Security
///
/// Header-based resolution is attempted only when the **direct peer** lies
/// inside one of the operator's trusted-proxy CIDRs. Any ambiguous chain
/// falls back to the direct peer, never to a header-supplied value. This
/// covers malformed or obfuscated entries, chains consisting solely of
/// trusted hops, and header bombs. The framework never logs this value
/// except on rate-limit deny paths.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[non_exhaustive]
pub struct ClientIp {
    /// The resolved client IP. This is the direct peer unless
    /// trusted-forwarder resolution applied.
    pub ip: IpAddr,
}

impl ClientIp {
    /// Wrap an already-resolved IP in a [`ClientIp`].
    ///
    /// Framework-internal only. Downstream code obtains `ClientIp` from
    /// request extensions rather than building it.
    #[must_use]
    pub(crate) const fn new(ip: IpAddr) -> Self {
        ClientIp { ip }
    }
}
191
192/// Which forwarding header trusted-forwarder mode reads.
193///
194/// TOML wire values are kebab-case: `"x-forwarded-for"` (default when
195/// unset) and `"forwarded"`.
196#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Deserialize)]
197#[serde(rename_all = "kebab-case")]
198#[non_exhaustive]
199pub enum ForwardedHeaderMode {
200    /// De-facto standard `X-Forwarded-For` list (nginx, HAProxy, CDNs).
201    XForwardedFor,
202    /// RFC 7239 `Forwarded` header (`for=` parameters).
203    Forwarded,
204}
205
206/// Pre-parsed trusted-forwarder configuration captured by the
207/// peer-normalization middleware.
208struct ForwardResolver {
209    trusted: Vec<ipnet::IpNet>,
210    mode: ForwardedHeaderMode,
211}
212
/// Overrides for the OWASP security headers that the global response
/// middleware adds to every response.
///
/// Every field is tri-state:
///
/// | Setting       | Effect                                                   |
/// |---------------|----------------------------------------------------------|
/// | `None`        | Keep the built-in default (current behaviour).           |
/// | `Some("")`    | **Suppress** the header completely.                      |
/// | `Some(value)` | Send `header: value`; checked when the config is loaded. |
///
/// [`McpServerConfig::validate`] runs every non-empty value through
/// [`axum::http::HeaderValue::from_str`]. A malformed override therefore
/// aborts startup before the server accepts any traffic.
///
/// `Strict-Transport-Security` has one extra rule: any value containing
/// the substring `preload` (compared case-insensitively) is refused. An
/// HSTS-preload commitment is meant to go through a future, explicit
/// builder method rather than being slipped in through this override.
#[derive(Clone, Default, Debug)]
#[non_exhaustive]
pub struct SecurityHeadersConfig {
    /// `X-Content-Type-Options` override; built-in value `nosniff`.
    pub x_content_type_options: Option<String>,
    /// `X-Frame-Options` override; built-in value `deny`.
    pub x_frame_options: Option<String>,
    /// `Cache-Control` override; built-in value `no-store, max-age=0`.
    pub cache_control: Option<String>,
    /// `Referrer-Policy` override; built-in value `no-referrer`.
    pub referrer_policy: Option<String>,
    /// `Cross-Origin-Opener-Policy` override; built-in value `same-origin`.
    pub cross_origin_opener_policy: Option<String>,
    /// `Cross-Origin-Resource-Policy` override; built-in value `same-origin`.
    pub cross_origin_resource_policy: Option<String>,
    /// `Cross-Origin-Embedder-Policy` override; built-in value `require-corp`.
    pub cross_origin_embedder_policy: Option<String>,
    /// `Permissions-Policy` override; built-in value
    /// `accelerometer=(), camera=(), geolocation=(), microphone=()`.
    pub permissions_policy: Option<String>,
    /// `X-Permitted-Cross-Domain-Policies` override; built-in value `none`.
    pub x_permitted_cross_domain_policies: Option<String>,
    /// `Content-Security-Policy` override; built-in value
    /// `default-src 'none'; frame-ancestors 'none'`.
    pub content_security_policy: Option<String>,
    /// `X-DNS-Prefetch-Control` override; built-in value `off`.
    pub x_dns_prefetch_control: Option<String>,
    /// `Strict-Transport-Security` override. The built-in value
    /// `max-age=63072000; includeSubDomains` is sent only when TLS is
    /// active, and on plaintext deployments this override has no effect.
    /// The validator refuses any value containing `preload` (any case).
    pub strict_transport_security: Option<String>,
}
266
267/// Configuration for the MCP server.
268#[allow(
269    missing_debug_implementations,
270    reason = "contains callback/trait objects that don't impl Debug"
271)]
272#[allow(
273    clippy::struct_excessive_bools,
274    reason = "server configuration naturally has many boolean feature flags"
275)]
276#[non_exhaustive]
277pub struct McpServerConfig {
278    /// Socket address the MCP HTTP server binds to.
279    #[deprecated(
280        since = "0.13.0",
281        note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in a future major release"
282    )]
283    pub bind_addr: String,
284    /// Server name advertised via MCP `initialize`.
285    #[deprecated(
286        since = "0.13.0",
287        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
288    )]
289    pub name: String,
290    /// Server version advertised via MCP `initialize`.
291    #[deprecated(
292        since = "0.13.0",
293        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
294    )]
295    pub version: String,
296    /// Path to the TLS certificate (PEM). Required for TLS/mTLS.
297    #[deprecated(
298        since = "0.13.0",
299        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
300    )]
301    pub tls_cert_path: Option<PathBuf>,
302    /// Path to the TLS private key (PEM). Required for TLS/mTLS.
303    #[deprecated(
304        since = "0.13.0",
305        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
306    )]
307    pub tls_key_path: Option<PathBuf>,
308    /// Optional authentication config. When `Some` and `enabled`, auth
309    /// is enforced on `/mcp`. `/healthz` is always open.
310    #[deprecated(
311        since = "0.13.0",
312        note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in a future major release"
313    )]
314    pub auth: Option<AuthConfig>,
315    /// Optional RBAC policy. When present and enabled, tool calls are
316    /// checked against the policy after authentication.
317    #[deprecated(
318        since = "0.13.0",
319        note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in a future major release"
320    )]
321    pub rbac: Option<Arc<RbacPolicy>>,
322    /// Allowed Origin values for DNS rebinding protection (MCP spec MUST).
323    /// When empty and `public_url` is set, the origin is auto-derived from
324    /// the public URL. When both are empty, only requests with no Origin
325    /// header are accepted.
326    /// Example entries: `"http://localhost:3000"`, `"https://myapp.example.com"`.
327    #[deprecated(
328        since = "0.13.0",
329        note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in a future major release"
330    )]
331    pub allowed_origins: Vec<String>,
332    /// Maximum tool invocations per source IP per minute.
333    /// When set, enforced on every `tools/call` request.
334    #[deprecated(
335        since = "0.13.0",
336        note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in a future major release"
337    )]
338    pub tool_rate_limit: Option<u32>,
339    /// Burst capacity for the tool rate limiter: maximum `tools/call`
340    /// requests admitted back-to-back before the sustained
341    /// [`tool_rate_limit`](Self::tool_rate_limit) rate applies. `None`
342    /// (default) keeps governor's default of burst = rate. Requires
343    /// `tool_rate_limit` to be set; must be greater than zero.
344    #[deprecated(
345        since = "1.12.0",
346        note = "use McpServerConfig::with_tool_rate_limit_burst(); direct field access will become pub(crate) in a future major release"
347    )]
348    pub tool_rate_limit_burst: Option<u32>,
349    /// Maximum requests per source IP per minute for routes merged via
350    /// [`with_extra_router`](Self::with_extra_router). Opt-in: `None`
351    /// (the default) installs no limiter. Startup-only (not
352    /// hot-reloadable via [`ReloadHandle`]).
353    ///
354    /// Keyed by the **direct socket peer** ([`PeerAddr`] semantics — no
355    /// `X-Forwarded-For` interpretation): behind a reverse proxy all
356    /// clients share the proxy's bucket, and IPv6 single-host address
357    /// rotation can evade per-IP keying. Treat this as an abuse speed
358    /// bump for unauthenticated application endpoints, not tenant
359    /// isolation. On limit: HTTP 429 with a plain-text body, matching
360    /// the tool/auth limiters.
361    #[deprecated(
362        since = "1.11.0",
363        note = "use McpServerConfig::with_extra_route_rate_limit(); direct field access will become pub(crate) in a future major release"
364    )]
365    pub extra_route_rate_limit: Option<u32>,
366    /// Burst capacity for the extra-route limiter: maximum requests
367    /// admitted back-to-back before the sustained
368    /// [`extra_route_rate_limit`](Self::extra_route_rate_limit) rate
369    /// applies. `None` (default) keeps governor's default of
370    /// burst = rate. Requires `extra_route_rate_limit` to be set; must
371    /// be greater than zero.
372    #[deprecated(
373        since = "1.12.0",
374        note = "use McpServerConfig::with_extra_route_rate_limit_burst(); direct field access will become pub(crate) in a future major release"
375    )]
376    pub extra_route_rate_limit_burst: Option<u32>,
377    /// Trusted reverse-proxy networks (CIDRs or bare IPs) for
378    /// **trusted-forwarder mode**. Empty (default) = mode off: every
379    /// limiter keys by the direct socket peer. Nonempty = requests whose
380    /// direct peer is inside one of these networks have their client IP
381    /// resolved from the forwarding header (rightmost-untrusted walk);
382    /// see [`ClientIp`]. Only enable when **all** ingress paths traverse
383    /// the listed proxies. Startup-only.
384    #[deprecated(
385        since = "1.13.0",
386        note = "use McpServerConfig::with_trusted_proxies(); direct field access will become pub(crate) in a future major release"
387    )]
388    pub trusted_proxies: Vec<String>,
389    /// Which forwarding header trusted-forwarder mode reads. `None`
390    /// (default) = `X-Forwarded-For`. Setting this requires
391    /// [`trusted_proxies`](Self::trusted_proxies) to be nonempty
392    /// (validated). Startup-only.
393    #[deprecated(
394        since = "1.13.0",
395        note = "use McpServerConfig::with_forwarded_header(); direct field access will become pub(crate) in a future major release"
396    )]
397    pub forwarded_header: Option<ForwardedHeaderMode>,
398    /// Optional readiness probe for `/readyz`.
399    /// When `None`, `/readyz` mirrors `/healthz` (always OK).
400    #[deprecated(
401        since = "0.13.0",
402        note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in a future major release"
403    )]
404    pub readiness_check: Option<ReadinessCheck>,
405    /// Maximum request body size in bytes. Default: 1 MiB.
406    /// Protects against oversized payloads causing OOM.
407    #[deprecated(
408        since = "0.13.0",
409        note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in a future major release"
410    )]
411    pub max_request_body: usize,
412    /// Request processing timeout. Default: 120s.
413    /// Requests exceeding this duration receive 408 Request Timeout.
414    #[deprecated(
415        since = "0.13.0",
416        note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in a future major release"
417    )]
418    pub request_timeout: Duration,
419    /// Graceful shutdown timeout. Default: 30s.
420    /// After the shutdown signal, in-flight requests have this long to finish.
421    #[deprecated(
422        since = "0.13.0",
423        note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in a future major release"
424    )]
425    pub shutdown_timeout: Duration,
426    /// Idle timeout for MCP sessions. Sessions with no activity for this
427    /// duration are closed automatically. Default: 20 minutes.
428    #[deprecated(
429        since = "0.13.0",
430        note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in a future major release"
431    )]
432    pub session_idle_timeout: Duration,
433    /// Interval for SSE keep-alive pings. Prevents proxies and load
434    /// balancers from killing idle connections. Default: 15 seconds.
435    #[deprecated(
436        since = "0.13.0",
437        note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in a future major release"
438    )]
439    pub sse_keep_alive: Duration,
440    /// Callback invoked once the server is built, delivering a
441    /// [`ReloadHandle`] for hot-reloading auth keys and RBAC policy
442    /// at runtime (e.g. on SIGHUP). Only useful when auth/RBAC is enabled.
443    #[deprecated(
444        since = "0.13.0",
445        note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in a future major release"
446    )]
447    pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
448    /// Additional application-specific routes merged into the top-level
449    /// router.  These routes **bypass** the MCP auth and RBAC middleware,
450    /// so the application is responsible for its own auth on them.
451    /// Handlers can extract [`PeerAddr`] (or
452    /// [`axum::extract::ConnectInfo`]`<SocketAddr>` for third-party
453    /// middleware compatibility) regardless of whether TLS is enabled.
454    #[deprecated(
455        since = "0.13.0",
456        note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in a future major release"
457    )]
458    pub extra_router: Option<axum::Router>,
459    /// Externally reachable base URL (e.g. `https://mcp.example.com`).
460    /// When set, OAuth metadata endpoints advertise this URL instead of
461    /// the listen address. Required when binding `0.0.0.0` behind a
462    /// reverse proxy or inside a container.
463    #[deprecated(
464        since = "0.13.0",
465        note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in a future major release"
466    )]
467    pub public_url: Option<String>,
468    /// Log inbound HTTP request headers at DEBUG level.
469    /// Sensitive values remain redacted.
470    #[deprecated(
471        since = "0.13.0",
472        note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in a future major release"
473    )]
474    pub log_request_headers: bool,
475    /// Enable gzip/br response compression on MCP responses.
476    /// Defaults to `false` to preserve existing behaviour.
477    #[deprecated(
478        since = "0.13.0",
479        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
480    )]
481    pub compression_enabled: bool,
482    /// Minimum response body size (in bytes) before compression kicks in.
483    /// Only used when `compression_enabled` is true. Default: 1024.
484    #[deprecated(
485        since = "0.13.0",
486        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
487    )]
488    pub compression_min_size: u16,
489    /// Global cap on in-flight HTTP requests across the whole server.
490    /// When `Some`, requests over the cap receive 503 Service Unavailable
491    /// via `tower::load_shed`. Default: `None` (unlimited).
492    #[deprecated(
493        since = "0.13.0",
494        note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in a future major release"
495    )]
496    pub max_concurrent_requests: Option<usize>,
497    /// Enable `/admin/*` diagnostic endpoints. Requires `auth` to be
498    /// configured and `enabled`. Default: `false`.
499    #[deprecated(
500        since = "0.13.0",
501        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
502    )]
503    pub admin_enabled: bool,
504    /// RBAC role required to access admin endpoints. Default: `"admin"`.
505    #[deprecated(
506        since = "0.13.0",
507        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
508    )]
509    pub admin_role: String,
510    /// Enable Prometheus metrics endpoint on a separate listener.
511    /// Requires the `metrics` crate feature.
512    #[cfg(feature = "metrics")]
513    #[deprecated(
514        since = "0.13.0",
515        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
516    )]
517    pub metrics_enabled: bool,
518    /// Bind address for the Prometheus metrics listener. Default: `127.0.0.1:9090`.
519    #[cfg(feature = "metrics")]
520    #[deprecated(
521        since = "0.13.0",
522        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
523    )]
524    pub metrics_bind: String,
525    /// Per-header overrides for the OWASP security headers emitted by
526    /// the global response middleware. See [`SecurityHeadersConfig`]
527    /// for the three-state semantic and validation rules.
528    #[deprecated(
529        since = "1.5.0",
530        note = "use McpServerConfig::with_security_headers(); direct field access will become pub(crate) in a future major release"
531    )]
532    pub security_headers: SecurityHeadersConfig,
533    /// Per-handshake deadline on the TLS accept path. Idle or slow-loris
534    /// connections are dropped once it elapses. Default: 10s.
535    ///
536    /// Startup-only: bound at listener construction, NOT hot-reloadable
537    /// via [`ReloadHandle`]. Ignored unless TLS is configured.
538    #[deprecated(
539        since = "1.9.0",
540        note = "use McpServerConfig::with_tls_handshake_timeout(); direct field access will become pub(crate) in a future major release"
541    )]
542    pub tls_handshake_timeout: Duration,
543    /// Cap on concurrently in-flight TLS handshakes. At saturation the
544    /// acceptor stops pulling new connections from the kernel backlog
545    /// (backpressure) instead of accepting and dropping. Default: 256.
546    ///
547    /// Startup-only: bound at listener construction, NOT hot-reloadable
548    /// via [`ReloadHandle`]. Ignored unless TLS is configured.
549    #[deprecated(
550        since = "1.9.0",
551        note = "use McpServerConfig::with_max_concurrent_tls_handshakes(); direct field access will become pub(crate) in a future major release"
552    )]
553    pub max_concurrent_tls_handshakes: usize,
554}
555
/// Proof token wrapping a value that has passed its validation contract.
///
/// [`McpServerConfig::validate`] is the only producer of
/// `Validated<McpServerConfig>`. Both [`serve`] and
/// [`serve_with_listener`] accept nothing else, so the type system
/// enforces validation. The wrapped field is private, so code outside this
/// module cannot skip validation by assembling the wrapper by hand.
///
/// Read the contents through [`Validated::as_inner`]. To change them,
/// unwrap with [`Validated::into_inner`] and validate the result again.
///
/// # Example
///
/// ```no_run
/// use rmcp_server_kit::transport::{McpServerConfig, Validated, serve};
/// use rmcp::handler::server::ServerHandler;
/// use rmcp::model::{ServerCapabilities, ServerInfo};
///
/// #[derive(Clone)]
/// struct H;
/// impl ServerHandler for H {
///     fn get_info(&self) -> ServerInfo {
///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
///     }
/// }
///
/// # async fn example() -> rmcp_server_kit::Result<()> {
/// let config: Validated<McpServerConfig> =
///     McpServerConfig::new("127.0.0.1:8080", "my-server", "0.1.0").validate()?;
/// serve(config, || H).await
/// # }
/// ```
///
/// Leaving out `.validate()?` does not compile:
///
/// ```compile_fail
/// use rmcp_server_kit::transport::{McpServerConfig, serve};
/// use rmcp::handler::server::ServerHandler;
/// use rmcp::model::{ServerCapabilities, ServerInfo};
///
/// #[derive(Clone)]
/// struct H;
/// impl ServerHandler for H {
///     fn get_info(&self) -> ServerInfo {
///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
///     }
/// }
///
/// # async fn example() -> rmcp_server_kit::Result<()> {
/// let config = McpServerConfig::new("127.0.0.1:8080", "my-server", "0.1.0");
/// // Missing `.validate()?` -> mismatched types: expected
/// // `Validated<McpServerConfig>`, found `McpServerConfig`.
/// serve(config, || H).await
/// # }
/// ```
#[allow(
    missing_debug_implementations,
    reason = "wraps T which may not implement Debug; manual impl below avoids leaking inner contents into logs"
)]
pub struct Validated<T>(T);

impl<T> std::fmt::Debug for Validated<T> {
    /// Opaque on purpose: prints only the wrapper name, never the wrapped
    /// value. This works for any `T`, including ones without `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Validated");
        out.finish_non_exhaustive()
    }
}

impl<T> Validated<T> {
    /// Consume the wrapper and return the raw value, dropping the
    /// validation proof.
    ///
    /// Validate the value again before handing it to [`serve`] or
    /// [`serve_with_listener`].
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }

    /// Shared reference to the wrapped value.
    #[must_use]
    pub fn as_inner(&self) -> &T {
        let Self(inner) = self;
        inner
    }
}
641
642#[allow(
643    deprecated,
644    reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
645)]
646impl McpServerConfig {
647    /// Create a new server configuration with the given bind address,
648    /// server name, and version. All other fields use safe defaults.
649    ///
650    /// Use the chainable `with_*` / `enable_*` builder methods to
651    /// customize. Call [`McpServerConfig::validate`] to obtain a
652    /// [`Validated<McpServerConfig>`] proof token, which is required by
653    /// [`serve`] and [`serve_with_listener`].
654    #[must_use]
655    pub fn new(
656        bind_addr: impl Into<String>,
657        name: impl Into<String>,
658        version: impl Into<String>,
659    ) -> Self {
660        Self {
661            bind_addr: bind_addr.into(),
662            name: name.into(),
663            version: version.into(),
664            tls_cert_path: None,
665            tls_key_path: None,
666            auth: None,
667            rbac: None,
668            allowed_origins: Vec::new(),
669            tool_rate_limit: None,
670            readiness_check: None,
671            max_request_body: 1024 * 1024,
672            request_timeout: Duration::from_mins(2),
673            shutdown_timeout: Duration::from_secs(30),
674            session_idle_timeout: Duration::from_mins(20),
675            sse_keep_alive: Duration::from_secs(15),
676            on_reload_ready: None,
677            extra_router: None,
678            public_url: None,
679            log_request_headers: false,
680            compression_enabled: false,
681            compression_min_size: 1024,
682            max_concurrent_requests: None,
683            admin_enabled: false,
684            admin_role: "admin".to_owned(),
685            #[cfg(feature = "metrics")]
686            metrics_enabled: false,
687            #[cfg(feature = "metrics")]
688            metrics_bind: "127.0.0.1:9090".into(),
689            security_headers: SecurityHeadersConfig::default(),
690            tls_handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
691            max_concurrent_tls_handshakes: DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES,
692            extra_route_rate_limit: None,
693            tool_rate_limit_burst: None,
694            extra_route_rate_limit_burst: None,
695            trusted_proxies: Vec::new(),
696            forwarded_header: None,
697        }
698    }
699
700    // ---------------------------------------------------------------
701    // Builder methods (fluent, consume + return self).
702    //
703    // Each method is `#[must_use]` because dropping the returned
704    // `McpServerConfig` discards the configuration change.
705    // ---------------------------------------------------------------
706
707    /// Attach an authentication configuration. Required for
708    /// [`enable_admin`](Self::enable_admin) and any non-public deployment.
709    #[must_use]
710    pub fn with_auth(mut self, auth: AuthConfig) -> Self {
711        self.auth = Some(auth);
712        self
713    }
714
715    /// Override one or more of the OWASP security headers emitted on
716    /// every response. See [`SecurityHeadersConfig`] for the three-state
717    /// semantic (`None` = default, `Some("")` = omit, `Some(v)` =
718    /// override). Values are validated by [`Self::validate`].
719    #[must_use]
720    pub fn with_security_headers(mut self, headers: SecurityHeadersConfig) -> Self {
721        self.security_headers = headers;
722        self
723    }
724
725    /// Override the bind address (e.g. `127.0.0.1:8080`). Useful when the
726    /// final port is only known after pre-binding an ephemeral listener
727    /// (tests, dynamic-port deployments).
728    #[must_use]
729    pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
730        self.bind_addr = addr.into();
731        self
732    }
733
734    /// Attach an RBAC policy. Tool calls are checked against the policy
735    /// after authentication.
736    #[must_use]
737    pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
738        self.rbac = Some(rbac);
739        self
740    }
741
742    /// Configure TLS by providing the certificate and private key paths
743    /// (PEM). Both must be readable at startup. Without this call, the
744    /// server runs plain HTTP.
745    #[must_use]
746    pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
747        self.tls_cert_path = Some(cert_path.into());
748        self.tls_key_path = Some(key_path.into());
749        self
750    }
751
752    /// Set the externally reachable base URL (e.g. `https://mcp.example.com`).
753    /// Required when binding `0.0.0.0` behind a reverse proxy or inside
754    /// a container so OAuth metadata and auto-derived origins resolve correctly.
755    #[must_use]
756    pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
757        self.public_url = Some(url.into());
758        self
759    }
760
761    /// Replace the allowed Origin allow-list (DNS-rebinding protection).
762    /// When empty and [`with_public_url`](Self::with_public_url) is set,
763    /// the origin is auto-derived.
764    #[must_use]
765    pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
766    where
767        I: IntoIterator<Item = S>,
768        S: Into<String>,
769    {
770        self.allowed_origins = origins.into_iter().map(Into::into).collect();
771        self
772    }
773
774    /// Merge an additional axum router at the top level. Routes added
775    /// here **bypass** rmcp-server-kit auth and RBAC; the application is responsible
776    /// for its own protection.
777    ///
778    /// To support that protection (e.g. per-IP rate limiting on
779    /// unauthenticated endpoints), every request served by [`serve`]
780    /// carries the client peer address regardless of whether TLS is
781    /// enabled: extract the framework-owned [`PeerAddr`] in your
782    /// handlers, or rely on [`axum::extract::ConnectInfo`]`<SocketAddr>`
783    /// for stock third-party middleware (e.g. per-IP rate-limit key
784    /// extractors). Neither extension exists under [`serve_stdio`],
785    /// which has no network peer.
786    #[must_use]
787    pub fn with_extra_router(mut self, router: axum::Router) -> Self {
788        self.extra_router = Some(router);
789        self
790    }
791
792    /// Install an async readiness probe for `/readyz`. Without this call,
793    /// `/readyz` mirrors `/healthz` (always 200 OK).
794    #[must_use]
795    pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
796        self.readiness_check = Some(check);
797        self
798    }
799
800    /// Override the maximum request body (bytes). Must be `> 0`.
801    /// Default: 1 MiB.
802    #[must_use]
803    pub fn with_max_request_body(mut self, bytes: usize) -> Self {
804        self.max_request_body = bytes;
805        self
806    }
807
808    /// Override the per-request processing timeout. Default: 2 minutes.
809    #[must_use]
810    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
811        self.request_timeout = timeout;
812        self
813    }
814
815    /// Override the graceful shutdown grace period. Default: 30 seconds.
816    #[must_use]
817    pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
818        self.shutdown_timeout = timeout;
819        self
820    }
821
822    /// Override the MCP session idle timeout. Default: 20 minutes.
823    #[must_use]
824    pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
825        self.session_idle_timeout = timeout;
826        self
827    }
828
829    /// Override the SSE keep-alive interval. Default: 15 seconds.
830    #[must_use]
831    pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
832        self.sse_keep_alive = interval;
833        self
834    }
835
836    /// Cap the global number of in-flight HTTP requests via
837    /// `tower::load_shed`. Excess requests receive 503 Service Unavailable.
838    /// Default: unlimited.
839    #[must_use]
840    pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
841        self.max_concurrent_requests = Some(limit);
842        self
843    }
844
845    /// Override the per-handshake deadline on the TLS accept path.
846    /// Idle or slow-loris connections are dropped once it elapses.
847    /// Default: 10s. Must be greater than zero.
848    ///
849    /// Startup-only: the value is bound at listener construction and is
850    /// NOT hot-reloadable via [`ReloadHandle`]. Has no effect unless TLS
851    /// is configured via [`Self::with_tls`].
852    #[must_use]
853    pub fn with_tls_handshake_timeout(mut self, timeout: Duration) -> Self {
854        self.tls_handshake_timeout = timeout;
855        self
856    }
857
858    /// Override the cap on concurrently in-flight TLS handshakes. At
859    /// saturation the acceptor stops pulling new connections from the
860    /// kernel backlog (backpressure) instead of accepting and dropping.
861    /// Default: 256. Must be greater than zero.
862    ///
863    /// Startup-only: the value is bound at listener construction and is
864    /// NOT hot-reloadable via [`ReloadHandle`]. Has no effect unless TLS
865    /// is configured via [`Self::with_tls`].
866    #[must_use]
867    pub fn with_max_concurrent_tls_handshakes(mut self, limit: usize) -> Self {
868        self.max_concurrent_tls_handshakes = limit;
869        self
870    }
871
872    /// Cap tool invocations per source IP per minute. Enforced on every
873    /// `tools/call` request.
874    #[must_use]
875    pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
876        self.tool_rate_limit = Some(per_minute);
877        self
878    }
879
880    /// Cap requests per source IP per minute on routes merged via
881    /// [`with_extra_router`](Self::with_extra_router) — the natural
882    /// protection for unauthenticated application endpoints (OAuth
883    /// callbacks, registration, …) that bypass auth/RBAC by design.
884    ///
885    /// Must be greater than zero (validated by
886    /// [`validate`](Self::validate)). Startup-only. See the
887    /// `extra_route_rate_limit` field docs for keying semantics and
888    /// caveats (direct peer only, IPv6 rotation, proxy collapse,
889    /// bounded-memory shared-fate under key spray).
890    #[must_use]
891    pub fn with_extra_route_rate_limit(mut self, per_minute: u32) -> Self {
892        self.extra_route_rate_limit = Some(per_minute);
893        self
894    }
895
896    /// Set the burst capacity for the tool rate limiter (bucket size;
897    /// the sustained rate stays [`with_tool_rate_limit`](Self::with_tool_rate_limit)).
898    /// Requires the tool rate limit to be set; must be greater than zero
899    /// (both validated by [`validate`](Self::validate)).
900    #[must_use]
901    pub fn with_tool_rate_limit_burst(mut self, burst: u32) -> Self {
902        self.tool_rate_limit_burst = Some(burst);
903        self
904    }
905
906    /// Set the burst capacity for the extra-route rate limiter (bucket
907    /// size; the sustained rate stays
908    /// [`with_extra_route_rate_limit`](Self::with_extra_route_rate_limit)).
909    /// Requires the extra-route rate limit to be set; must be greater
910    /// than zero (both validated by [`validate`](Self::validate)).
911    #[must_use]
912    pub fn with_extra_route_rate_limit_burst(mut self, burst: u32) -> Self {
913        self.extra_route_rate_limit_burst = Some(burst);
914        self
915    }
916
917    /// Enable **trusted-forwarder mode**: requests whose direct peer is
918    /// inside one of these networks (CIDRs or bare IPs) have their
919    /// client IP resolved from the forwarding header via the
920    /// rightmost-untrusted walk; all per-IP rate limiters then key by
921    /// the resolved [`ClientIp`]. Headers from peers outside these
922    /// networks are ignored entirely.
923    ///
924    /// Only enable when **all** ingress paths traverse the listed
925    /// proxies — otherwise direct clients keep their own buckets and
926    /// proxied clients collapse into the proxy's. Entries are validated
927    /// at [`validate`](Self::validate) time. Startup-only.
928    #[must_use]
929    pub fn with_trusted_proxies<I, S>(mut self, proxies: I) -> Self
930    where
931        I: IntoIterator<Item = S>,
932        S: Into<String>,
933    {
934        self.trusted_proxies = proxies.into_iter().map(Into::into).collect();
935        self
936    }
937
938    /// Select which forwarding header trusted-forwarder mode reads
939    /// (default: `X-Forwarded-For`). Requires
940    /// [`with_trusted_proxies`](Self::with_trusted_proxies) to be set
941    /// (validated).
942    #[must_use]
943    pub fn with_forwarded_header(mut self, mode: ForwardedHeaderMode) -> Self {
944        self.forwarded_header = Some(mode);
945        self
946    }
947
948    /// Register a callback that receives the [`ReloadHandle`] after the
949    /// server is built. Use it to wire SIGHUP-style hot reloads of API
950    /// keys and RBAC policy.
951    #[must_use]
952    pub fn with_reload_callback<F>(mut self, callback: F) -> Self
953    where
954        F: FnOnce(ReloadHandle) + Send + 'static,
955    {
956        self.on_reload_ready = Some(Box::new(callback));
957        self
958    }
959
960    /// Enable gzip/brotli response compression on MCP responses.
961    /// `min_size` is the smallest body size (bytes) eligible for
962    /// compression. Default min size: 1024.
963    #[must_use]
964    pub fn enable_compression(mut self, min_size: u16) -> Self {
965        self.compression_enabled = true;
966        self.compression_min_size = min_size;
967        self
968    }
969
970    /// Enable `/admin/*` diagnostic endpoints. Requires
971    /// [`with_auth`](Self::with_auth) to be set and enabled; otherwise
972    /// [`validate`](Self::validate) returns an error. `role` is the RBAC
973    /// role gate (default: `"admin"`).
974    #[must_use]
975    pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
976        self.admin_enabled = true;
977        self.admin_role = role.into();
978        self
979    }
980
981    /// Log inbound HTTP request headers at DEBUG level. Sensitive
982    /// values remain redacted by the logging layer.
983    #[must_use]
984    pub fn enable_request_header_logging(mut self) -> Self {
985        self.log_request_headers = true;
986        self
987    }
988
989    /// Enable the Prometheus metrics listener on `bind` (e.g.
990    /// `127.0.0.1:9090`). Requires the `metrics` crate feature.
991    #[cfg(feature = "metrics")]
992    #[must_use]
993    pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
994        self.metrics_enabled = true;
995        self.metrics_bind = bind.into();
996        self
997    }
998
999    /// Validate the configuration and consume `self`, returning a
1000    /// [`Validated<McpServerConfig>`] proof token required by [`serve`]
1001    /// and [`serve_with_listener`]. This is the only way to construct
1002    /// `Validated<McpServerConfig>`, so the type system guarantees
1003    /// validation has run before the server starts.
1004    ///
1005    /// Checks:
1006    ///
1007    /// 1. `admin_enabled` requires `auth` to be configured and enabled.
1008    /// 2. `tls_cert_path` and `tls_key_path` must both be set or both
1009    ///    be unset.
1010    /// 3. `bind_addr` must parse as a [`SocketAddr`].
1011    /// 4. `public_url`, when set, must start with `http://` or `https://`.
1012    /// 5. Each entry in `allowed_origins` must start with `http://` or
1013    ///    `https://`.
1014    /// 6. `max_request_body` must be greater than zero.
1015    /// 7. When the `oauth` feature is enabled and an [`OAuthConfig`] is
1016    ///    present, all OAuth URL fields (`jwks_uri`, `proxy.authorize_url`,
1017    ///    `proxy.token_url`, `proxy.introspection_url`,
1018    ///    `proxy.revocation_url`, `token_exchange.token_url`) must parse
1019    ///    and use the `https` scheme. Set
1020    ///    [`OAuthConfig::allow_http_oauth_urls`] to permit `http://`
1021    ///    targets (strongly discouraged in production - see the field-level
1022    ///    docs for the threat model).
1023    ///
1024    /// [`OAuthConfig`]: crate::oauth::OAuthConfig
1025    /// [`OAuthConfig::allow_http_oauth_urls`]: crate::oauth::OAuthConfig::allow_http_oauth_urls
1026    ///
1027    /// # Errors
1028    ///
1029    /// Returns [`McpxError::Config`] with a human-readable message on
1030    /// the first validation failure.
1031    pub fn validate(self) -> Result<Validated<Self>, McpxError> {
1032        self.check()?;
1033        Ok(Validated(self))
1034    }
1035
1036    /// Validate the burst knobs: every burst must be greater than zero
1037    /// when set, and the two top-level bursts require their base limiter
1038    /// to be configured. The auth bursts
1039    /// (`RateLimitConfig::{burst, pre_auth_burst}`) have no orphan rule:
1040    /// their base rates always resolve (`max_attempts_per_minute` is
1041    /// mandatory; the pre-auth base derives from it when unset).
1042    fn check_burst_knobs(&self) -> Result<(), McpxError> {
1043        if self.tool_rate_limit_burst == Some(0) {
1044            return Err(McpxError::Config(
1045                "tool_rate_limit_burst must be greater than zero".into(),
1046            ));
1047        }
1048        if self.extra_route_rate_limit_burst == Some(0) {
1049            return Err(McpxError::Config(
1050                "extra_route_rate_limit_burst must be greater than zero".into(),
1051            ));
1052        }
1053        if self.tool_rate_limit_burst.is_some() && self.tool_rate_limit.is_none() {
1054            return Err(McpxError::Config(
1055                "tool_rate_limit_burst requires tool_rate_limit to be set".into(),
1056            ));
1057        }
1058        if self.extra_route_rate_limit_burst.is_some() && self.extra_route_rate_limit.is_none() {
1059            return Err(McpxError::Config(
1060                "extra_route_rate_limit_burst requires extra_route_rate_limit to be set".into(),
1061            ));
1062        }
1063        if let Some(rl) = self.auth.as_ref().and_then(|a| a.rate_limit.as_ref()) {
1064            if rl.burst == Some(0) {
1065                return Err(McpxError::Config(
1066                    "auth rate_limit.burst must be greater than zero".into(),
1067                ));
1068            }
1069            if rl.pre_auth_burst == Some(0) {
1070                return Err(McpxError::Config(
1071                    "auth rate_limit.pre_auth_burst must be greater than zero".into(),
1072                ));
1073            }
1074        }
1075        Ok(())
1076    }
1077
1078    /// Validate the trusted-forwarder knobs: every `trusted_proxies`
1079    /// entry must parse as a CIDR (`ipnet::IpNet`) or a bare IP
1080    /// (normalized to a host network), and `forwarded_header` requires a
1081    /// nonempty proxy list (fail-fast over a silent no-op).
1082    fn check_trusted_forwarder(&self) -> Result<(), McpxError> {
1083        for entry in &self.trusted_proxies {
1084            if parse_proxy_net(entry).is_none() {
1085                return Err(McpxError::Config(format!(
1086                    "trusted_proxies entry {entry:?} is neither a CIDR nor an IP address"
1087                )));
1088            }
1089        }
1090        if self.forwarded_header.is_some() && self.trusted_proxies.is_empty() {
1091            return Err(McpxError::Config(
1092                "forwarded_header requires trusted_proxies to be nonempty".into(),
1093            ));
1094        }
1095        Ok(())
1096    }
1097
    /// Run every configuration check against a borrowed `self`.
    ///
    /// [`validate`](Self::validate) calls this before wrapping the config
    /// in [`Validated`]. Because it only borrows, internal callers that
    /// need to keep ownership of the config can run the same checks.
    ///
    /// The checks run in a fixed order and return on the first failure, so
    /// the statement order below decides which error a config with several
    /// problems reports.
    ///
    /// # Errors
    ///
    /// Returns [`McpxError::Config`] for the first violated rule checked
    /// here. Errors from the OAuth validation (feature `oauth`) and from
    /// `validate_security_headers` are propagated with `?`.
    fn check(&self) -> Result<(), McpxError> {
        // 1. admin <-> auth dependency. Mirrors the runtime check in
        //    `build_app_router`: admin endpoints require an auth state,
        //    which is built only when `auth` is `Some` *and* `enabled`.
        if self.admin_enabled {
            let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
            if !auth_enabled {
                return Err(McpxError::Config(
                    "admin_enabled=true requires auth to be configured and enabled".into(),
                ));
            }
        }

        // 2. TLS cert / key must be paired. Both-set enables TLS and
        //    both-unset means plain HTTP; a half-configured pair is an error.
        match (&self.tls_cert_path, &self.tls_key_path) {
            (Some(_), None) => {
                return Err(McpxError::Config(
                    "tls_cert_path is set but tls_key_path is missing".into(),
                ));
            }
            (None, Some(_)) => {
                return Err(McpxError::Config(
                    "tls_key_path is set but tls_cert_path is missing".into(),
                ));
            }
            _ => {}
        }

        // 3. bind_addr parses. `SocketAddr` parsing accepts only literal
        //    `ip:port` forms, so host names such as `localhost:8080` fail.
        if self.bind_addr.parse::<SocketAddr>().is_err() {
            return Err(McpxError::Config(format!(
                "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
                self.bind_addr
            )));
        }

        // 4. public_url scheme. Only the scheme prefix is checked; the rest
        //    of the URL is not parsed here.
        if let Some(ref url) = self.public_url
            && !(url.starts_with("http://") || url.starts_with("https://"))
        {
            return Err(McpxError::Config(format!(
                "public_url {url:?} must start with http:// or https://"
            )));
        }

        // 5. allowed_origins scheme (prefix check per entry; first bad
        //    entry is reported).
        for origin in &self.allowed_origins {
            if !(origin.starts_with("http://") || origin.starts_with("https://")) {
                return Err(McpxError::Config(format!(
                    "allowed_origins entry {origin:?} must start with http:// or https://"
                )));
            }
        }

        // 6. max_request_body > 0
        if self.max_request_body == 0 {
            return Err(McpxError::Config(
                "max_request_body must be greater than zero".into(),
            ));
        }

        // 6b. extra_route_rate_limit, when set, must be > 0. Unlike the
        // legacy tool_rate_limit (which clamps 0 to its default at
        // construction), new knobs fail fast on nonsensical values.
        if self.extra_route_rate_limit == Some(0) {
            return Err(McpxError::Config(
                "extra_route_rate_limit must be greater than zero".into(),
            ));
        }

        // 6c. Burst knobs (extracted helper).
        self.check_burst_knobs()?;

        // 6d. Trusted-forwarder knobs (extracted helper).
        self.check_trusted_forwarder()?;

        // 7. OAuth URL fields enforce HTTPS (unless `allow_http_oauth_urls`)
        //    -- delegated to the OAuth config's own `validate`.
        #[cfg(feature = "oauth")]
        if let Some(auth_cfg) = &self.auth
            && let Some(oauth_cfg) = &auth_cfg.oauth
        {
            oauth_cfg.validate()?;
        }

        // 8. Security-header overrides parse as valid HTTP header values,
        //    and HSTS does not smuggle in a `preload` directive.
        validate_security_headers(&self.security_headers)?;

        // 9. max_concurrent_requests must be > 0 when set. Zero would
        //    deadlock the global concurrency limiter and reject every
        //    request. Mirrors the TOML-side check in `src/config.rs`.
        if let Some(0) = self.max_concurrent_requests {
            return Err(McpxError::Config(
                "max_concurrent_requests must be greater than zero when set".into(),
            ));
        }

        // 10. Auth rate-limit `max_tracked_keys` must be > 0. A zero cap
        //     would force `BoundedKeyedLimiter` to evict on every insert
        //     and effectively disable rate limiting.
        if let Some(auth_cfg) = &self.auth
            && let Some(rl) = &auth_cfg.rate_limit
            && rl.max_tracked_keys == 0
        {
            return Err(McpxError::Config(
                "auth.rate_limit.max_tracked_keys must be greater than zero".into(),
            ));
        }

        // 11. tls_handshake_timeout must be > 0. A zero deadline would
        //     reap every handshake before it could complete, rejecting
        //     all TLS connections. Mirrors the TOML-side check in
        //     `src/config.rs`.
        if self.tls_handshake_timeout == Duration::ZERO {
            return Err(McpxError::Config(
                "tls_handshake_timeout must be greater than zero".into(),
            ));
        }

        // 12. max_concurrent_tls_handshakes must be > 0. A zero-permit
        //     semaphore would never admit a handshake, deadlocking the
        //     TLS accept path. Mirrors the TOML-side check in
        //     `src/config.rs`.
        if self.max_concurrent_tls_handshakes == 0 {
            return Err(McpxError::Config(
                "max_concurrent_tls_handshakes must be greater than zero".into(),
            ));
        }

        Ok(())
    }
1232}
1233
1234/// Handle for hot-reloading server configuration without restart.
1235///
1236/// Obtained via [`McpServerConfig::on_reload_ready`].
1237/// All swap operations are lock-free and wait-free -- in-flight requests
1238/// finish with the old values while new requests see the update immediately.
1239#[allow(
1240    missing_debug_implementations,
1241    reason = "contains Arc<AuthState> with non-Debug fields"
1242)]
1243pub struct ReloadHandle {
1244    auth: Option<Arc<AuthState>>,
1245    rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
1246    crl_set: Option<Arc<CrlSet>>,
1247}
1248
1249impl ReloadHandle {
1250    /// Atomically replace the API key list used by the auth middleware.
1251    pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
1252        if let Some(ref auth) = self.auth {
1253            auth.reload_keys(keys);
1254        }
1255    }
1256
1257    /// Atomically replace the RBAC policy used by the RBAC middleware.
1258    pub fn reload_rbac(&self, policy: RbacPolicy) {
1259        if let Some(ref rbac) = self.rbac {
1260            rbac.store(Arc::new(policy));
1261            tracing::info!("RBAC policy reloaded");
1262        }
1263    }
1264
1265    /// Force an immediate refresh of all cached mTLS CRLs.
1266    ///
1267    /// # Errors
1268    ///
1269    /// Returns an error if CRL refresh is unavailable or verifier rebuild fails.
1270    pub async fn refresh_crls(&self) -> Result<(), McpxError> {
1271        let Some(ref crl_set) = self.crl_set else {
1272            return Err(McpxError::Config(
1273                "CRL refresh requested but mTLS CRL support is not configured".into(),
1274            ));
1275        };
1276
1277        crl_set.force_refresh().await
1278    }
1279}
1280
1281/// Generic MCP HTTP server.
1282///
1283/// Wraps an axum server with `/healthz` and `/mcp` endpoints.
1284/// When `tls_cert_path` and `tls_key_path` are both set, the server binds
1285/// with TLS (rustls). Optionally supports mTLS client certificate auth.
1286///
1287/// # Errors
1288///
1289/// Returns an error if the TCP listener cannot bind, TLS config is invalid,
1290/// or the server fails.
1291// NOTE: cognitive complexity reduced from 111/25 to 83/25 by
1292// extracting `run_server` (serve-loop tail) and `install_oauth_proxy_routes`.
1293// Remaining flow is a linear router builder: middleware layering, feature-
1294// gated auth/RBAC wiring, and PRM/metrics installation. Further extraction
1295// would require threading many `&mut Router` helpers and hurt readability
1296// of the layer order (which is security-relevant and must stay visible).
1297#[allow(
1298    clippy::too_many_lines,
1299    clippy::cognitive_complexity,
1300    reason = "middleware layer order is security-critical and must remain visible at one glance; extracting `&mut Router` helpers would obscure the auth/RBAC/origin/rate-limit ordering"
1301)]
1302/// Internal bundle of values produced by [`build_app_router`] and
1303/// consumed by [`serve`] / [`serve_with_listener`] when driving the
1304/// HTTP listener.
1305struct AppRunParams {
1306    /// TLS cert/key paths when TLS is configured.
1307    tls_paths: Option<(PathBuf, PathBuf)>,
1308    /// Per-handshake deadline on the TLS accept path.
1309    tls_handshake_timeout: Duration,
1310    /// Cap on concurrently in-flight TLS handshakes.
1311    max_concurrent_tls_handshakes: usize,
1312    /// mTLS configuration when mutual-TLS auth is enabled.
1313    mtls_config: Option<MtlsConfig>,
1314    /// Graceful shutdown drain window.
1315    shutdown_timeout: Duration,
1316    /// Shared auth state used by hot-reload callbacks.
1317    auth_state: Option<Arc<AuthState>>,
1318    /// Hot-reloadable RBAC state used by reload callbacks.
1319    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
1320    /// Optional callback that receives the final [`ReloadHandle`].
1321    on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
1322    /// Server-internal cancellation token. Cancelled by [`run_server`]
1323    /// once the shutdown trigger fires (so rmcp's child token also
1324    /// fires, terminating in-flight MCP sessions).
1325    ct: CancellationToken,
1326    /// `"http"` or `"https"` -- used only for boot-time logging.
1327    scheme: &'static str,
1328    /// Server name -- used only for boot-time logging.
1329    name: String,
1330}
1331
1332/// Build the full application axum [`axum::Router`] (MCP route +
1333/// middleware stack + admin + OAuth + health endpoints + security
1334/// headers + CORS + compression + concurrency limit + origin check)
1335/// and the [`AppRunParams`] needed to drive it.
1336///
1337/// This is the shared core of [`serve`] and [`serve_with_listener`].
1338/// It performs *no* network I/O: callers are responsible for binding
1339/// (or accepting a pre-bound) [`TcpListener`] and invoking
1340/// [`run_server`].
1341#[allow(
1342    clippy::cognitive_complexity,
1343    reason = "router assembly is intrinsically sequential; splitting harms readability"
1344)]
1345#[allow(
1346    deprecated,
1347    reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
1348)]
1349fn build_app_router<H, F>(
1350    mut config: McpServerConfig,
1351    handler_factory: F,
1352) -> anyhow::Result<(axum::Router, AppRunParams)>
1353where
1354    H: ServerHandler + 'static,
1355    F: Fn() -> H + Send + Sync + Clone + 'static,
1356{
1357    let ct = CancellationToken::new();
1358
1359    let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
1360    tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");
1361
1362    let mcp_service = StreamableHttpService::new(
1363        move || Ok(handler_factory()),
1364        {
1365            let mut mgr = LocalSessionManager::default();
1366            mgr.session_config.keep_alive = Some(config.session_idle_timeout);
1367            mgr.into()
1368        },
1369        StreamableHttpServerConfig::default()
1370            .with_allowed_hosts(allowed_hosts)
1371            .with_sse_keep_alive(Some(config.sse_keep_alive))
1372            .with_cancellation_token(ct.child_token()),
1373    );
1374
1375    // Build the MCP route, optionally wrapped with auth and RBAC middleware.
1376    let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);
1377
1378    // Build auth state eagerly when auth is configured so we can wire both
1379    // the auth middleware *and* the optional admin router against the same
1380    // state. The middleware itself is installed further down in layer order.
1381    let auth_state: Option<Arc<AuthState>> = match config.auth {
1382        Some(ref auth_config) if auth_config.enabled => {
1383            let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
1384            let pre_auth_limiter = auth_config
1385                .rate_limit
1386                .as_ref()
1387                .map(crate::auth::build_pre_auth_limiter);
1388
1389            #[cfg(feature = "oauth")]
1390            let jwks_cache = auth_config
1391                .oauth
1392                .as_ref()
1393                .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
1394                .transpose()
1395                .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;
1396
1397            Some(Arc::new(AuthState {
1398                api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
1399                rate_limiter,
1400                pre_auth_limiter,
1401                #[cfg(feature = "oauth")]
1402                jwks_cache,
1403                seen_identities: crate::auth::SeenIdentitySet::new(),
1404                counters: crate::auth::AuthCounters::default(),
1405            }))
1406        }
1407        _ => None,
1408    };
1409
1410    // Build the RBAC policy swap early so the admin router and the later
1411    // RBAC middleware layer share the same hot-reloadable state.
1412    let rbac_swap = Arc::new(ArcSwap::new(
1413        config
1414            .rbac
1415            .clone()
1416            .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
1417    ));
1418
1419    // Optional /admin/* diagnostic routes. Merged BEFORE the
1420    // body-limit/timeout/RBAC/origin/auth layers so all of them apply.
1421    if config.admin_enabled {
1422        let Some(ref auth_state_ref) = auth_state else {
1423            return Err(anyhow::anyhow!(
1424                "admin_enabled=true requires auth to be configured and enabled"
1425            ));
1426        };
1427        let admin_state = crate::admin::AdminState {
1428            started_at: std::time::Instant::now(),
1429            name: config.name.clone(),
1430            version: config.version.clone(),
1431            auth: Some(Arc::clone(auth_state_ref)),
1432            rbac: Arc::clone(&rbac_swap),
1433        };
1434        let admin_cfg = crate::admin::AdminConfig {
1435            role: config.admin_role.clone(),
1436        };
1437        mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
1438        tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
1439    }
1440
1441    // ----- Middleware order (CRITICAL: read carefully) ------------------
1442    //
1443    // axum/tower applies layers **bottom-up** at runtime: the LAST layer
1444    // added is the OUTERMOST (runs first on a request). To achieve a
1445    // request-time flow of:
1446    //
1447    //   body-limit -> timeout -> auth -> rbac -> handler
1448    //
1449    // we add layers in the REVERSE order:
1450    //
1451    //   1. RBAC               (innermost, runs last before handler)
1452    //   2. auth               (parses identity, sets extension for RBAC)
1453    //   3. timeout            (bounds total request time)
1454    //   4. body-limit         (outermost on /mcp; caps payload before
1455    //                          anything else reads/buffers it)
1456    //
1457    // Origin validation is installed on the OUTER router (after the
1458    // /mcp router is merged in), so it also protects /healthz, /readyz,
1459    // /version, and any OAuth proxy endpoints.
1460    //
1461    // Rationale:
1462    // - Body-limit must be outermost on /mcp so RBAC (which reads the
1463    //   JSON-RPC body) cannot be DoS'd by a 100MB payload.
1464    // - Auth must run before RBAC because RBAC consumes
1465    //   `req.extensions().get::<AuthIdentity>()` to enforce per-role
1466    //   policy.
1467    // - Origin runs before auth so we reject cross-origin requests
1468    //   without spending Argon2 cycles on unauthenticated callers.
1469
1470    // [1] RBAC + tool rate-limit layer (innermost; closest to handler).
1471    // Always installed: even when RBAC is disabled, tool rate limiting may
1472    // be active (MCP spec: servers MUST rate limit tool invocations).
1473    {
1474        let tool_limiter: Option<Arc<ToolRateLimiter>> = config
1475            .tool_rate_limit
1476            .map(|per_minute| build_tool_rate_limiter(per_minute, config.tool_rate_limit_burst));
1477
1478        if rbac_swap.load().is_enabled() {
1479            tracing::info!("RBAC enforcement enabled on /mcp");
1480        }
1481        if let Some(limit) = config.tool_rate_limit {
1482            tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
1483        }
1484
1485        let rbac_for_mw = Arc::clone(&rbac_swap);
1486        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1487            let p = rbac_for_mw.load_full();
1488            let tl = tool_limiter.clone();
1489            rbac_middleware(p, tl, req, next)
1490        }));
1491    }
1492
1493    // [2] Auth layer (runs before RBAC so AuthIdentity is in extensions).
1494    if let Some(ref auth_config) = config.auth
1495        && auth_config.enabled
1496    {
1497        let Some(ref state) = auth_state else {
1498            return Err(anyhow::anyhow!("auth state missing despite enabled config"));
1499        };
1500
1501        let methods: Vec<&str> = [
1502            auth_config.mtls.is_some().then_some("mTLS"),
1503            (!auth_config.api_keys.is_empty()).then_some("bearer"),
1504            #[cfg(feature = "oauth")]
1505            auth_config.oauth.is_some().then_some("oauth-jwt"),
1506        ]
1507        .into_iter()
1508        .flatten()
1509        .collect();
1510
1511        tracing::info!(
1512            methods = %methods.join(", "),
1513            api_keys = auth_config.api_keys.len(),
1514            "auth enabled on /mcp"
1515        );
1516
1517        let state_for_mw = Arc::clone(state);
1518        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1519            let s = Arc::clone(&state_for_mw);
1520            auth_middleware(s, req, next)
1521        }));
1522    }
1523
1524    // [3] Request timeout (returns 408 on expiry). Bounds total request
1525    // duration including auth + handler.
1526    mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
1527        axum::http::StatusCode::REQUEST_TIMEOUT,
1528        config.request_timeout,
1529    ));
1530
1531    // [4] Request body size limit (OUTERMOST on /mcp). Prevents OOM /
1532    // DoS from oversized payloads BEFORE any inner layer (auth, RBAC)
1533    // attempts to buffer or parse the body.
1534    mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
1535        config.max_request_body,
1536    ));
1537
1538    // Compute the effective allowed-origins list for the outer
1539    // origin-check layer (installed on the merged router below). When
1540    // `allowed_origins` is empty but `public_url` is set, auto-derive
1541    // the origin from the public URL so MCP clients (e.g. Claude Code)
1542    // that send `Origin: <server-url>` are accepted without explicit
1543    // config.
1544    let mut effective_origins = config.allowed_origins.clone();
1545    if effective_origins.is_empty()
1546        && let Some(ref url) = config.public_url
1547    {
1548        // Origin = scheme + "://" + host (+ ":" + port if non-default).
1549        // Strip any path/query from the public URL. Offsets come from
1550        // `find`, so they are char-boundary-aligned; `get(..)` keeps that
1551        // machine-checked (a violation degrades to an empty slice).
1552        if let Some(scheme_end) = url.find("://") {
1553            let scheme_with_sep = url.get(..scheme_end + 3).unwrap_or_default();
1554            let after_scheme = url.get(scheme_end + 3..).unwrap_or_default();
1555            let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
1556            let host = after_scheme.get(..host_end).unwrap_or_default();
1557            let origin = format!("{scheme_with_sep}{host}");
1558            tracing::info!(
1559                %origin,
1560                "auto-derived allowed origin from public_url"
1561            );
1562            effective_origins.push(origin);
1563        }
1564    }
1565    let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
1566    let cors_origins = Arc::clone(&allowed_origins);
1567    let log_request_headers = config.log_request_headers;
1568
1569    let readyz_route = if let Some(check) = config.readiness_check.take() {
1570        axum::routing::get(move || readyz(Arc::clone(&check)))
1571    } else {
1572        axum::routing::get(healthz)
1573    };
1574
1575    #[allow(unused_mut)] // mut needed when oauth feature adds PRM route
1576    let mut router = axum::Router::new()
1577        .route("/healthz", axum::routing::get(healthz))
1578        .route("/readyz", readyz_route)
1579        .route(
1580            "/version",
1581            axum::routing::get({
1582                // Pre-serialize the version payload once at router-build
1583                // time. The handler then serves a cheap `Arc::clone` of the
1584                // immutable bytes per request, avoiding `serde_json::Value`
1585                // allocation + serialization on every `/version` hit.
1586                let payload_bytes: Arc<[u8]> =
1587                    serialize_version_payload(&config.name, &config.version);
1588                move || {
1589                    let p = Arc::clone(&payload_bytes);
1590                    async move {
1591                        (
1592                            [(axum::http::header::CONTENT_TYPE, "application/json")],
1593                            p.to_vec(),
1594                        )
1595                    }
1596                }
1597            }),
1598        )
1599        .merge(mcp_router);
1600
1601    // Merge application-specific routes (bypass MCP auth/RBAC middleware).
1602    // When configured, wrap them — and only them — in the per-IP rate
1603    // limiter BEFORE merging: axum layers wrap only the routes already
1604    // present on the sub-router, so the limiter can never leak onto
1605    // `/mcp`, health, admin, or OAuth endpoints, while top-level layers
1606    // (origin check, peer-address normalization, ...) still run first.
1607    if let Some(extra) = config.extra_router.take() {
1608        let extra = match config.extra_route_rate_limit {
1609            Some(per_minute) => {
1610                let limiter =
1611                    build_extra_route_rate_limiter(per_minute, config.extra_route_rate_limit_burst);
1612                tracing::info!(per_minute, "extra-route per-IP rate limit enabled");
1613                extra.layer(axum::middleware::from_fn(move |req, next| {
1614                    let l = Arc::clone(&limiter);
1615                    extra_route_rate_limit_middleware(l, req, next)
1616                }))
1617            }
1618            None => extra,
1619        };
1620        router = router.merge(extra);
1621    }
1622
1623    // RFC 9728: Protected Resource Metadata endpoint.
1624    // When OAuth is configured, serve full metadata with authorization_servers.
1625    // Otherwise, serve a minimal document with just the resource URL and no
1626    // authorization_servers -- this tells MCP clients (e.g. Claude Code SDK)
1627    // that the server exists but does NOT require OAuth authentication,
1628    // preventing them from gating the connection behind a broken auth flow.
1629    let server_url = if let Some(ref url) = config.public_url {
1630        url.trim_end_matches('/').to_owned()
1631    } else {
1632        let prm_scheme = if config.tls_cert_path.is_some() {
1633            "https"
1634        } else {
1635            "http"
1636        };
1637        format!("{prm_scheme}://{}", config.bind_addr)
1638    };
1639    let resource_url = format!("{server_url}/mcp");
1640
1641    #[cfg(feature = "oauth")]
1642    let prm_metadata = if let Some(ref auth_config) = config.auth
1643        && let Some(ref oauth_config) = auth_config.oauth
1644    {
1645        crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
1646    } else {
1647        serde_json::json!({ "resource": resource_url })
1648    };
1649    #[cfg(not(feature = "oauth"))]
1650    let prm_metadata = serde_json::json!({ "resource": resource_url });
1651
1652    router = router.route(
1653        "/.well-known/oauth-protected-resource",
1654        axum::routing::get(move || {
1655            let m = prm_metadata.clone();
1656            async move { axum::Json(m) }
1657        }),
1658    );
1659
1660    // OAuth 2.1 proxy endpoints: when an OAuth proxy is configured, expose
1661    // /authorize, /token, /register, and authorization server metadata so
1662    // MCP clients can perform Authorization Code + PKCE against the upstream
1663    // IdP (e.g. Keycloak) transparently.
1664    #[cfg(feature = "oauth")]
1665    if let Some(ref auth_config) = config.auth
1666        && let Some(ref oauth_config) = auth_config.oauth
1667        && oauth_config.proxy.is_some()
1668    {
1669        router =
1670            install_oauth_proxy_routes(router, &server_url, oauth_config, auth_state.as_ref())?;
1671    }
1672
1673    // OWASP security response headers (applied to all responses).
1674    // HSTS is conditional on TLS being configured.
1675    let is_tls = config.tls_cert_path.is_some();
1676    let security_headers_cfg = Arc::new(config.security_headers.clone());
1677    router = router.layer(axum::middleware::from_fn(move |req, next| {
1678        let cfg = Arc::clone(&security_headers_cfg);
1679        security_headers_middleware(is_tls, cfg, req, next)
1680    }));
1681
1682    // CORS preflight layer (required for browser-based MCP clients).
1683    // Uses the same effective origins as the origin check middleware
1684    // (including auto-derived origin from public_url).
1685    if !cors_origins.is_empty() {
1686        let cors = tower_http::cors::CorsLayer::new()
1687            .allow_origin(
1688                cors_origins
1689                    .iter()
1690                    .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
1691                    .collect::<Vec<_>>(),
1692            )
1693            .allow_methods([
1694                axum::http::Method::GET,
1695                axum::http::Method::POST,
1696                axum::http::Method::OPTIONS,
1697            ])
1698            .allow_headers([
1699                axum::http::header::CONTENT_TYPE,
1700                axum::http::header::AUTHORIZATION,
1701            ]);
1702        router = router.layer(cors);
1703    }
1704
1705    // Optional response compression (gzip + brotli). Skips small bodies
1706    // to avoid overhead. Applied after CORS so preflight responses remain
1707    // uncompressed.
1708    if config.compression_enabled {
1709        use tower_http::compression::Predicate as _;
1710        let predicate = tower_http::compression::DefaultPredicate::new().and(
1711            tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
1712        );
1713        router = router.layer(
1714            tower_http::compression::CompressionLayer::new()
1715                .gzip(true)
1716                .br(true)
1717                .compress_when(predicate),
1718        );
1719        tracing::info!(
1720            min_size = config.compression_min_size,
1721            "response compression enabled (gzip, br)"
1722        );
1723    }
1724
1725    // Optional global concurrency cap. `load_shed` converts the
1726    // `ConcurrencyLimit` back-pressure error into 503 instead of hanging.
1727    if let Some(max) = config.max_concurrent_requests {
1728        let overload_handler = tower::ServiceBuilder::new()
1729            .layer(axum::error_handling::HandleErrorLayer::new(
1730                |_err: tower::BoxError| async {
1731                    (
1732                        axum::http::StatusCode::SERVICE_UNAVAILABLE,
1733                        axum::Json(serde_json::json!({
1734                            "error": "overloaded",
1735                            "error_description": "server is at capacity, retry later"
1736                        })),
1737                    )
1738                },
1739            ))
1740            .layer(tower::load_shed::LoadShedLayer::new())
1741            .layer(tower::limit::ConcurrencyLimitLayer::new(max));
1742        router = router.layer(overload_handler);
1743        tracing::info!(max, "global concurrency limit enabled");
1744    }
1745
1746    // JSON fallback for unmatched routes. Without this, axum returns
1747    // an empty-body 404 that breaks MCP clients (e.g. Claude Code SDK)
1748    // when they probe OAuth endpoints like /authorize or /token.
1749    router = router.fallback(|| async {
1750        (
1751            axum::http::StatusCode::NOT_FOUND,
1752            axum::Json(serde_json::json!({
1753                "error": "not_found",
1754                "error_description": "The requested endpoint does not exist"
1755            })),
1756        )
1757    });
1758
1759    // Prometheus metrics: recording middleware + separate listener.
1760    #[cfg(feature = "metrics")]
1761    if config.metrics_enabled {
1762        let metrics = Arc::new(
1763            crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
1764        );
1765        let m = Arc::clone(&metrics);
1766        router = router.layer(axum::middleware::from_fn(
1767            move |req: Request<Body>, next: Next| {
1768                let m = Arc::clone(&m);
1769                metrics_middleware(m, req, next)
1770            },
1771        ));
1772        let metrics_bind = config.metrics_bind.clone();
1773        let metrics_shutdown = ct.clone();
1774        tokio::spawn(async move {
1775            if let Err(e) =
1776                crate::metrics::serve_metrics(metrics_bind, metrics, metrics_shutdown).await
1777            {
1778                tracing::error!("metrics listener failed: {e}");
1779            }
1780        });
1781    }
1782
1783    // Peer-address normalization. Mirrors the TLS branch's peer address
1784    // into `ConnectInfo<SocketAddr>` and exposes the framework-owned
1785    // `PeerAddr` extension on both listener branches, so ALL routes on
1786    // the merged router (`/mcp`, `/healthz`, OAuth proxy endpoints,
1787    // admin endpoints, extra_router, ...) and all inner middleware see a
1788    // uniform peer-address contract regardless of TLS. Installed just
1789    // inside the origin check, which stays outermost by design.
1790    let forward_resolver: Option<Arc<ForwardResolver>> = if config.trusted_proxies.is_empty() {
1791        None
1792    } else {
1793        // Entries are guaranteed parseable by `check_trusted_forwarder`;
1794        // filter_map is defensive only.
1795        Some(Arc::new(ForwardResolver {
1796            trusted: config
1797                .trusted_proxies
1798                .iter()
1799                .filter_map(|entry| parse_proxy_net(entry))
1800                .collect(),
1801            mode: config
1802                .forwarded_header
1803                .unwrap_or(ForwardedHeaderMode::XForwardedFor),
1804        }))
1805    };
1806    if forward_resolver.is_some() {
1807        tracing::info!(
1808            proxies = config.trusted_proxies.len(),
1809            "trusted-forwarder mode enabled: limiters key by resolved client IP"
1810        );
1811    }
1812    router = router.layer(axum::middleware::from_fn(move |req, next| {
1813        let r = forward_resolver.clone();
1814        normalize_peer_addr_middleware(r, req, next)
1815    }));
1816
1817    // Origin validation layer (MCP spec: servers MUST validate the
1818    // Origin header to prevent DNS rebinding attacks). Installed as the
1819    // OUTERMOST layer on the OUTER router so it protects ALL routes
1820    // (`/mcp`, `/healthz`, `/readyz`, `/version`, OAuth proxy endpoints,
1821    // admin endpoints, extra_router, etc.) and runs BEFORE auth so we
1822    // reject cross-origin attackers without spending Argon2 cycles.
1823    //
1824    // Origin-less requests (e.g. server-to-server probes, curl, native
1825    // MCP clients) are permitted; only requests with an Origin header
1826    // that does not match `effective_origins` are rejected.
1827    router = router.layer(axum::middleware::from_fn(move |req, next| {
1828        let origins = Arc::clone(&allowed_origins);
1829        origin_check_middleware(origins, log_request_headers, req, next)
1830    }));
1831
1832    let scheme = if config.tls_cert_path.is_some() {
1833        "https"
1834    } else {
1835        "http"
1836    };
1837
1838    let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
1839        (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
1840        _ => None,
1841    };
1842    let tls_handshake_timeout = config.tls_handshake_timeout;
1843    let max_concurrent_tls_handshakes = config.max_concurrent_tls_handshakes;
1844    let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();
1845
1846    Ok((
1847        router,
1848        AppRunParams {
1849            tls_paths,
1850            tls_handshake_timeout,
1851            max_concurrent_tls_handshakes,
1852            mtls_config,
1853            shutdown_timeout: config.shutdown_timeout,
1854            auth_state,
1855            rbac_swap,
1856            on_reload_ready: config.on_reload_ready.take(),
1857            ct,
1858            scheme,
1859            name: config.name.clone(),
1860        },
1861    ))
1862}
1863
1864/// Run the MCP HTTP server, binding to `config.bind_addr` and serving
1865/// until an OS shutdown signal (Ctrl-C / SIGTERM) is received.
1866///
1867/// This is the standard entry point for production deployments. For
1868/// deterministic shutdown control (e.g. integration tests), see
1869/// [`serve_with_listener`].
1870///
1871/// The configuration must be validated first via
1872/// [`McpServerConfig::validate`], which returns a [`Validated`] proof
1873/// token. This typestate guarantees, at compile time, that the server
1874/// never starts with an invalid configuration.
1875///
1876/// # Errors
1877///
1878/// Returns [`McpxError::Startup`] if binding to `config.bind_addr`
1879/// fails, or if the underlying axum server returns an error.
1880pub async fn serve<H, F>(
1881    config: Validated<McpServerConfig>,
1882    handler_factory: F,
1883) -> Result<(), McpxError>
1884where
1885    H: ServerHandler + 'static,
1886    F: Fn() -> H + Send + Sync + Clone + 'static,
1887{
1888    let config = config.into_inner();
1889    #[allow(
1890        deprecated,
1891        reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1892    )]
1893    let bind_addr = config.bind_addr.clone();
1894    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1895
1896    let listener = TcpListener::bind(&bind_addr)
1897        .await
1898        .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1899    log_listening(&params.name, params.scheme, &bind_addr);
1900
1901    run_server(
1902        router,
1903        listener,
1904        params.tls_paths,
1905        params.tls_handshake_timeout,
1906        params.max_concurrent_tls_handshakes,
1907        params.mtls_config,
1908        params.shutdown_timeout,
1909        params.auth_state,
1910        params.rbac_swap,
1911        params.on_reload_ready,
1912        params.ct,
1913    )
1914    .await
1915    .map_err(anyhow_to_startup)
1916}
1917
1918/// Run the MCP HTTP server on a pre-bound [`TcpListener`], with optional
1919/// readiness signalling and external shutdown control.
1920///
1921/// This variant is intended for **deterministic integration tests** and
1922/// for embedders that need to bind the listening socket themselves
1923/// (e.g. systemd socket activation). Compared to [`serve`]:
1924///
1925/// * The caller passes a `TcpListener` that is already bound. This
1926///   eliminates the bind race in tests that previously required
1927///   poll-the-`/healthz`-loop start-up detection.
1928/// * `ready_tx`, when `Some`, receives the socket's
1929///   [`SocketAddr`] *after* the router is built and immediately before
1930///   the server starts accepting connections. Tests can `await` the
1931///   matching `oneshot::Receiver` to know exactly when it is safe to
1932///   issue requests.
1933/// * `shutdown`, when `Some`, gives the caller a
1934///   [`CancellationToken`] that triggers the same graceful-shutdown
1935///   path as a real OS signal. This avoids cross-platform issues with
1936///   sending real `SIGTERM` from tests on Windows.
1937///
1938/// All three optional parameters degrade gracefully: if `ready_tx` is
1939/// `None`, no signal is sent; if `shutdown` is `None`, the server only
1940/// stops on an OS signal (just like [`serve`]).
1941///
1942/// # Errors
1943///
1944/// Returns [`McpxError::Startup`] if router construction fails, if reading
1945/// the listener's `local_addr()` fails, or if the underlying axum
1946/// server returns an error.
1947pub async fn serve_with_listener<H, F>(
1948    listener: TcpListener,
1949    config: Validated<McpServerConfig>,
1950    handler_factory: F,
1951    ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
1952    shutdown: Option<CancellationToken>,
1953) -> Result<(), McpxError>
1954where
1955    H: ServerHandler + 'static,
1956    F: Fn() -> H + Send + Sync + Clone + 'static,
1957{
1958    let config = config.into_inner();
1959    let local_addr = listener
1960        .local_addr()
1961        .map_err(|e| io_to_startup("listener.local_addr", e))?;
1962    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1963
1964    log_listening(&params.name, params.scheme, &local_addr.to_string());
1965
1966    // Forward external shutdown into the server-internal cancellation
1967    // token so `run_server`'s shutdown trigger picks it up alongside
1968    // any real OS signal.
1969    if let Some(external) = shutdown {
1970        let internal = params.ct.clone();
1971        tokio::spawn(async move {
1972            external.cancelled().await;
1973            internal.cancel();
1974        });
1975    }
1976
1977    // Signal readiness *after* the router is fully built and external
1978    // shutdown is wired, but *before* run_server takes ownership of
1979    // the listener. The receiver can immediately issue requests.
1980    if let Some(tx) = ready_tx {
1981        // Receiver may have been dropped (test gave up). That's fine.
1982        let _ = tx.send(local_addr);
1983    }
1984
1985    run_server(
1986        router,
1987        listener,
1988        params.tls_paths,
1989        params.tls_handshake_timeout,
1990        params.max_concurrent_tls_handshakes,
1991        params.mtls_config,
1992        params.shutdown_timeout,
1993        params.auth_state,
1994        params.rbac_swap,
1995        params.on_reload_ready,
1996        params.ct,
1997    )
1998    .await
1999    .map_err(anyhow_to_startup)
2000}
2001
2002/// Emit the standard "listening on …" log lines used by both
2003/// [`serve`] and [`serve_with_listener`].
2004#[allow(
2005    clippy::cognitive_complexity,
2006    reason = "tracing::info! macro expansions inflate the score; logic is trivial"
2007)]
2008fn log_listening(name: &str, scheme: &str, addr: &str) {
2009    tracing::info!("{name} listening on {addr}");
2010    tracing::info!("  MCP endpoint: {scheme}://{addr}/mcp");
2011    tracing::info!("  Health check: {scheme}://{addr}/healthz");
2012    tracing::info!("  Readiness:   {scheme}://{addr}/readyz");
2013}
2014
2015/// Drive the chosen axum server variant (TLS or plain) with a graceful
2016/// shutdown window. Consumes the router and listener.
2017///
2018/// # Shutdown semantics
2019///
2020/// A single shutdown trigger (the FIRST of: OS signal via
2021/// `shutdown_signal()`, or external cancellation of `ct`) starts BOTH:
2022///
2023/// 1. axum's `.with_graceful_shutdown(...)` future, which stops
2024///    accepting new connections and waits for in-flight requests to
2025///    drain;
2026/// 2. a `tokio::time::sleep(shutdown_timeout)` race that forces exit if
2027///    drainage exceeds `shutdown_timeout`.
2028///
2029/// Previously this function awaited `shutdown_signal()` independently
2030/// in BOTH branches of a `tokio::select!`. Because `shutdown_signal`
2031/// resolves once per future and consumes one signal, the force-exit
2032/// timer was tied to a SECOND signal (a second SIGTERM the operator
2033/// would never send). Under a single SIGTERM the graceful drain could
2034/// hang indefinitely. The current implementation derives both branches
2035/// from a single shared trigger so the timeout race is anchored to the
2036/// FIRST (and only) signal.
2037#[allow(
2038    clippy::too_many_arguments,
2039    clippy::cognitive_complexity,
2040    reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
2041)]
2042async fn run_server(
2043    router: axum::Router,
2044    listener: TcpListener,
2045    tls_paths: Option<(PathBuf, PathBuf)>,
2046    tls_handshake_timeout: Duration,
2047    max_concurrent_tls_handshakes: usize,
2048    mtls_config: Option<MtlsConfig>,
2049    shutdown_timeout: Duration,
2050    auth_state: Option<Arc<AuthState>>,
2051    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
2052    mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
2053    ct: CancellationToken,
2054) -> anyhow::Result<()> {
2055    // `shutdown_trigger` fires when the FIRST source resolves: either
2056    // an OS signal (Ctrl-C / SIGTERM) or external cancellation of `ct`
2057    // (which the test harness uses for deterministic shutdown).
2058    let shutdown_trigger = CancellationToken::new();
2059    {
2060        let trigger = shutdown_trigger.clone();
2061        let parent = ct.clone();
2062        tokio::spawn(async move {
2063            tokio::select! {
2064                () = shutdown_signal() => {}
2065                () = parent.cancelled() => {}
2066            }
2067            trigger.cancel();
2068        });
2069    }
2070
2071    let graceful = {
2072        let trigger = shutdown_trigger.clone();
2073        let ct = ct.clone();
2074        async move {
2075            trigger.cancelled().await;
2076            tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
2077            ct.cancel();
2078        }
2079    };
2080
2081    let force_exit_timer = {
2082        let trigger = shutdown_trigger.clone();
2083        async move {
2084            trigger.cancelled().await;
2085            tokio::time::sleep(shutdown_timeout).await;
2086        }
2087    };
2088
2089    if let Some((cert_path, key_path)) = tls_paths {
2090        let crl_set = if let Some(mtls) = mtls_config.as_ref()
2091            && mtls.crl_enabled
2092        {
2093            let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
2094            let (crl_set, discover_rx) =
2095                mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
2096                    .await
2097                    .map_err(|error| anyhow::anyhow!(error.to_string()))?;
2098            tokio::spawn(mtls_revocation::run_crl_refresher(
2099                Arc::clone(&crl_set),
2100                discover_rx,
2101                ct.clone(),
2102            ));
2103            Some(crl_set)
2104        } else {
2105            None
2106        };
2107
2108        if let Some(cb) = on_reload_ready.take() {
2109            cb(ReloadHandle {
2110                auth: auth_state.clone(),
2111                rbac: Some(Arc::clone(&rbac_swap)),
2112                crl_set: crl_set.clone(),
2113            });
2114        }
2115
2116        let tls_listener = TlsListener::new(
2117            listener,
2118            &cert_path,
2119            &key_path,
2120            mtls_config.as_ref(),
2121            crl_set,
2122            tls_handshake_timeout,
2123            max_concurrent_tls_handshakes,
2124        )?;
2125        let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
2126        tokio::select! {
2127            result = axum::serve(tls_listener, make_svc)
2128                .with_graceful_shutdown(graceful) => { result?; }
2129            () = force_exit_timer => {
2130                tracing::warn!("shutdown timeout exceeded, forcing exit");
2131            }
2132        }
2133    } else {
2134        if let Some(cb) = on_reload_ready.take() {
2135            cb(ReloadHandle {
2136                auth: auth_state,
2137                rbac: Some(rbac_swap),
2138                crl_set: None,
2139            });
2140        }
2141
2142        let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
2143        tokio::select! {
2144            result = axum::serve(listener, make_svc)
2145                .with_graceful_shutdown(graceful) => { result?; }
2146            () = force_exit_timer => {
2147                tracing::warn!("shutdown timeout exceeded, forcing exit");
2148            }
2149        }
2150    }
2151
2152    Ok(())
2153}
2154
/// Install the OAuth 2.1 proxy endpoints (`/authorize`, `/token`,
/// `/register`, and authorization server metadata) on `router`.
///
/// When `oauth_config.proxy` is `None` this is a no-op and `router` is
/// returned unchanged. The `/introspect` and `/revoke` admin endpoints are
/// merged in only when `expose_admin_endpoints` is set and at least one of
/// the introspection / revocation upstream URLs is configured.
///
/// # Errors
///
/// Returns [`McpxError::Startup`] if the shared
/// [`crate::oauth::OauthHttpClient`] cannot be initialized, or if the admin
/// endpoints are mounted with `require_auth_on_admin_endpoints` set but no
/// `auth_state` was supplied.
#[cfg(feature = "oauth")]
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let Some(ref proxy) = oauth_config.proxy else {
        return Ok(router);
    };

    // Single shared HTTP client for all proxy endpoints. Cloning is
    // cheap (refcounted) and shares the underlying connection pool.
    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;

    let asm = crate::oauth::authorization_server_metadata(server_url, oauth_config);
    let router = router.route(
        "/.well-known/oauth-authorization-server",
        axum::routing::get(move || {
            let m = asm.clone();
            async move { axum::Json(m) }
        }),
    );

    let proxy_authorize = proxy.clone();
    let router = router.route(
        "/authorize",
        axum::routing::get(
            move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
                let p = proxy_authorize.clone();
                async move { crate::oauth::handle_authorize(&p, &query.unwrap_or_default()) }
            },
        ),
    );

    let proxy_token = proxy.clone();
    let token_http = http.clone();
    let router = router.route(
        "/token",
        axum::routing::post(move |body: String| {
            let p = proxy_token.clone();
            let h = token_http.clone();
            async move { crate::oauth::handle_token(&h, &p, &body).await }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    let proxy_register = proxy.clone();
    let router = router.route(
        "/register",
        axum::routing::post(move |axum::Json(body): axum::Json<serde_json::Value>| {
            let p = proxy_register;
            async move { axum::Json(crate::oauth::handle_register(&p, &body)) }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    // Admin routes are mounted only when explicitly exposed AND at least one
    // upstream (introspection or revocation) exists to proxy to.
    let admin_routes_enabled = proxy.expose_admin_endpoints
        && (proxy.introspection_url.is_some() || proxy.revocation_url.is_some());
    if admin_routes_enabled
        && !proxy.require_auth_on_admin_endpoints
        && proxy.allow_unauthenticated_admin_endpoints
    {
        // M3 escape-hatch in effect: validate() let this through because
        // the operator explicitly opted in. Surface it loudly at startup
        // so the choice is auditable in logs. Gated on
        // `admin_routes_enabled` so the warning only fires when the
        // unauthenticated routes are actually mounted.
        tracing::warn!(
            "OAuth introspect/revoke endpoints are unauthenticated by explicit \
             allow_unauthenticated_admin_endpoints opt-out; ensure an \
             authenticated reverse proxy fronts these routes"
        );
    }

    let admin_router = if admin_routes_enabled {
        build_oauth_admin_router(proxy, http, auth_state)?
    } else {
        axum::Router::new()
    };

    let router = router.merge(admin_router);

    tracing::info!(
        introspect = proxy.expose_admin_endpoints && proxy.introspection_url.is_some(),
        revoke = proxy.expose_admin_endpoints && proxy.revocation_url.is_some(),
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
2255
/// Assemble the optional admin sub-router exposing `/introspect` and/or
/// `/revoke`. Each route is added only when the matching upstream URL is
/// configured on `proxy`.
///
/// All admin routes are wrapped in [`oauth_token_cache_headers_middleware`]
/// (RFC 6749 §5.1 / RFC 6750 §5.4 cache headers). When
/// `proxy.require_auth_on_admin_endpoints` is set, the auth middleware is
/// layered on top using `auth_state`.
///
/// # Errors
///
/// Returns [`McpxError::Startup`] when authentication is required but
/// `auth_state` is `None`.
#[cfg(feature = "oauth")]
fn build_oauth_admin_router(
    proxy: &crate::oauth::OAuthProxyConfig,
    http: crate::oauth::OauthHttpClient,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    // Resolve the auth requirement first so a misconfiguration is reported
    // before any routes are assembled.
    let guard_state = match (proxy.require_auth_on_admin_endpoints, auth_state) {
        (false, _) => None,
        (true, Some(state)) => Some(Arc::clone(state)),
        (true, None) => {
            return Err(McpxError::Startup(
                "oauth proxy admin endpoints require auth state".into(),
            ));
        }
    };

    let mut routes = axum::Router::new();

    if proxy.introspection_url.is_some() {
        let introspect_cfg = proxy.clone();
        let introspect_client = http.clone();
        routes = routes.route(
            "/introspect",
            axum::routing::post(move |body: String| {
                let cfg = introspect_cfg.clone();
                let client = introspect_client.clone();
                async move { crate::oauth::handle_introspect(&client, &cfg, &body).await }
            }),
        );
    }

    if proxy.revocation_url.is_some() {
        let revoke_cfg = proxy.clone();
        let revoke_client = http;
        routes = routes.route(
            "/revoke",
            axum::routing::post(move |body: String| {
                let cfg = revoke_cfg.clone();
                let client = revoke_client.clone();
                async move { crate::oauth::handle_revoke(&client, &cfg, &body).await }
            }),
        );
    }

    let routes = routes.layer(axum::middleware::from_fn(
        oauth_token_cache_headers_middleware,
    ));

    Ok(match guard_state {
        Some(state) => routes.layer(axum::middleware::from_fn(move |req, next| {
            let s = Arc::clone(&state);
            auth_middleware(s, req, next)
        })),
        None => routes,
    })
}
2314
2315/// Build the host allow-list for rmcp's DNS rebinding protection.
2316///
2317/// Includes loopback hosts by default, then augments with host/authority
2318/// derived from `public_url` and the server bind address.
2319fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
2320    let mut hosts = vec![
2321        "localhost".to_owned(),
2322        "127.0.0.1".to_owned(),
2323        "::1".to_owned(),
2324    ];
2325
2326    if let Some(url) = public_url
2327        && let Ok(uri) = url.parse::<axum::http::Uri>()
2328        && let Some(authority) = uri.authority()
2329    {
2330        let host = authority.host().to_owned();
2331        if !hosts.iter().any(|h| h == &host) {
2332            hosts.push(host);
2333        }
2334
2335        let authority = authority.as_str().to_owned();
2336        if !hosts.iter().any(|h| h == &authority) {
2337            hosts.push(authority);
2338        }
2339    }
2340
2341    if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
2342        && let Some(authority) = uri.authority()
2343    {
2344        let host = authority.host().to_owned();
2345        if !hosts.iter().any(|h| h == &host) {
2346            hosts.push(host);
2347        }
2348
2349        let authority = authority.as_str().to_owned();
2350        if !hosts.iter().any(|h| h == &authority) {
2351            hosts.push(authority);
2352        }
2353    }
2354
2355    hosts
2356}
2357
2358// - TLS support -
2359
2360/// Implement axum's `Connected` trait for `TlsConnInfo` so that
2361/// `ConnectInfo<TlsConnInfo>` is available in middleware when serving
2362/// over our custom `TlsListener`.
2363///
2364/// The identity is read directly from the wrapping
2365/// [`AuthenticatedTlsStream`], which guarantees one-to-one correspondence
2366/// between the TLS connection and its mTLS identity. This eliminates the
2367/// previous shared-map approach which was vulnerable to ephemeral-port
2368/// reuse races (an unauthenticated reconnection from the same `(IP, port)`
2369/// pair could alias a stale entry).
2370impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
2371    for TlsConnInfo
2372{
2373    fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
2374        let addr = *target.remote_addr();
2375        let identity = target.io().identity().cloned();
2376        TlsConnInfo::new(addr, identity)
2377    }
2378}
2379
/// Default per-handshake deadline (10 seconds) on the TLS accept path.
/// Prevents idle or slow-loris connections from pinning handshake worker
/// tasks (and their semaphore permits) indefinitely.
///
/// Configurable since 1.9.0 via
/// [`McpServerConfig::with_tls_handshake_timeout`].
const DEFAULT_TLS_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

/// Default upper bound on concurrently in-flight TLS handshakes. When
/// saturated, the acceptor task stops pulling new connections from the
/// kernel backlog (backpressure) instead of accepting and dropping them
/// in user space.
///
/// Configurable since 1.9.0 via
/// [`McpServerConfig::with_max_concurrent_tls_handshakes`].
const DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES: usize = 256;

/// Capacity of the completed-handshake queue between the acceptor task and
/// `axum::serve`'s `accept()` loop. Handshake workers block on `send` when
/// the queue is full, so a slow accept loop back-pressures handshakes
/// rather than buffering completed connections unboundedly. Not currently
/// user-configurable.
const TLS_ACCEPT_CHANNEL_CAPACITY: usize = 32;
2402
/// A TLS-wrapping listener that implements axum's `Listener` trait.
///
/// TCP accepts and TLS handshakes run on a dedicated background task: each
/// accepted connection's handshake is spawned onto its own worker task,
/// bounded by a configurable concurrent-handshake cap (default
/// [`DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES`]) and a per-handshake timeout
/// (default [`DEFAULT_TLS_HANDSHAKE_TIMEOUT`]). A slow or idle client
/// therefore cannot stall other connections behind a serialized inline
/// handshake.
///
/// When mTLS is configured, client certificates are verified against the
/// configured CA (or the CRL-aware dynamic verifier) and the client
/// identity is extracted at handshake time. The extracted identity is bound
/// to the connection itself via the returned [`AuthenticatedTlsStream`], so
/// it is impossible for an unrelated connection to observe it.
struct TlsListener {
    /// Bound address, captured eagerly before the `TcpListener` moves into
    /// the acceptor task (so `local_addr()` never needs the socket).
    local_addr: SocketAddr,
    /// Completed handshakes produced by the acceptor task's workers; each
    /// item is a fully-handshaken stream plus the peer address.
    rx: mpsc::Receiver<(AuthenticatedTlsStream, SocketAddr)>,
    /// Background task driving TCP accepts and concurrent TLS handshakes.
    /// Aborted on drop so the listener releases its port deterministically.
    acceptor_task: tokio::task::JoinHandle<()>,
}
2428
2429impl TlsListener {
2430    fn new(
2431        inner: TcpListener,
2432        cert_path: &Path,
2433        key_path: &Path,
2434        mtls_config: Option<&MtlsConfig>,
2435        crl_set: Option<Arc<CrlSet>>,
2436        handshake_timeout: Duration,
2437        max_concurrent_handshakes: usize,
2438    ) -> anyhow::Result<Self> {
2439        // Install the ring crypto provider (ok to call multiple times).
2440        rustls::crypto::ring::default_provider()
2441            .install_default()
2442            .ok();
2443
2444        let certs = load_certs(cert_path)?;
2445        let key = load_key(key_path)?;
2446
2447        let mtls_default_role;
2448
2449        let tls_config = if let Some(mtls) = mtls_config {
2450            mtls_default_role = mtls.default_role.clone();
2451            let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
2452            {
2453                let Some(crl_set) = crl_set else {
2454                    return Err(anyhow::anyhow!(
2455                        "mTLS CRL verifier requested but CRL state was not initialized"
2456                    ));
2457                };
2458                Arc::new(DynamicClientCertVerifier::new(crl_set))
2459            } else {
2460                let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
2461                if mtls.required {
2462                    rustls::server::WebPkiClientVerifier::builder(root_store)
2463                        .build()
2464                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
2465                } else {
2466                    rustls::server::WebPkiClientVerifier::builder(root_store)
2467                        .allow_unauthenticated()
2468                        .build()
2469                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
2470                }
2471            };
2472
2473            tracing::info!(
2474                ca = %mtls.ca_cert_path.display(),
2475                required = mtls.required,
2476                crl_enabled = mtls.crl_enabled,
2477                "mTLS client auth configured"
2478            );
2479
2480            rustls::ServerConfig::builder_with_protocol_versions(&[
2481                &rustls::version::TLS12,
2482                &rustls::version::TLS13,
2483            ])
2484            .with_client_cert_verifier(verifier)
2485            .with_single_cert(certs, key)?
2486        } else {
2487            mtls_default_role = "viewer".to_owned();
2488            rustls::ServerConfig::builder_with_protocol_versions(&[
2489                &rustls::version::TLS12,
2490                &rustls::version::TLS13,
2491            ])
2492            .with_no_client_auth()
2493            .with_single_cert(certs, key)?
2494        };
2495
2496        let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
2497        tracing::info!(
2498            "TLS enabled (cert: {}, key: {})",
2499            cert_path.display(),
2500            key_path.display()
2501        );
2502        let local_addr = inner.local_addr()?;
2503        let (tx, rx) = mpsc::channel(TLS_ACCEPT_CHANNEL_CAPACITY);
2504        let acceptor_task = tokio::spawn(run_tls_acceptor(
2505            inner,
2506            acceptor,
2507            mtls_default_role,
2508            tx,
2509            handshake_timeout,
2510            max_concurrent_handshakes,
2511        ));
2512        Ok(Self {
2513            local_addr,
2514            rx,
2515            acceptor_task,
2516        })
2517    }
2518
2519    /// Extract the mTLS client cert identity from a completed TLS handshake.
2520    /// Returns `None` if no client certificate was presented or if the
2521    /// certificate could not be parsed into an [`AuthIdentity`].
2522    fn extract_handshake_identity(
2523        tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
2524        default_role: &str,
2525        addr: SocketAddr,
2526    ) -> Option<AuthIdentity> {
2527        let (_, server_conn) = tls_stream.get_ref();
2528        let cert_der = server_conn.peer_certificates()?.first()?;
2529        let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
2530        tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
2531        Some(id)
2532    }
2533}
2534
2535/// Drive TCP accepts and concurrent TLS handshakes for [`TlsListener`].
2536///
2537/// Each accepted connection's handshake runs on its own worker task under
2538/// a permit from a `max_concurrent_handshakes`-sized semaphore and a
2539/// `handshake_timeout` deadline. Completed handshakes are pushed to `tx`;
2540/// failures and timeouts are logged at DEBUG and the connection dropped.
2541/// The loop exits when the owning [`TlsListener`] is dropped.
2542async fn run_tls_acceptor(
2543    listener: TcpListener,
2544    acceptor: tokio_rustls::TlsAcceptor,
2545    default_role: String,
2546    tx: mpsc::Sender<(AuthenticatedTlsStream, SocketAddr)>,
2547    handshake_timeout: Duration,
2548    max_concurrent_handshakes: usize,
2549) {
2550    let inflight = Arc::new(Semaphore::new(max_concurrent_handshakes));
2551    loop {
2552        // Acquire the permit BEFORE accepting: at saturation, pending
2553        // connections wait in the kernel backlog instead of being accepted
2554        // and then buffered or dropped in user space.
2555        let Ok(permit) = Arc::clone(&inflight).acquire_owned().await else {
2556            // The semaphore is never closed; defensive exit.
2557            return;
2558        };
2559        let (stream, addr) = match listener.accept().await {
2560            Ok(pair) => pair,
2561            Err(e) => {
2562                tracing::debug!("TCP accept error: {e}");
2563                continue;
2564            }
2565        };
2566        if tx.is_closed() {
2567            // The listener was dropped (shutdown): stop accepting.
2568            return;
2569        }
2570        let acceptor = acceptor.clone();
2571        let default_role = default_role.clone();
2572        let tx = tx.clone();
2573        tokio::spawn(async move {
2574            let _permit = permit;
2575            match tokio::time::timeout(handshake_timeout, acceptor.accept(stream)).await {
2576                Ok(Ok(tls_stream)) => {
2577                    let identity =
2578                        TlsListener::extract_handshake_identity(&tls_stream, &default_role, addr);
2579                    let wrapped = AuthenticatedTlsStream {
2580                        inner: tls_stream,
2581                        identity,
2582                    };
2583                    // The receiver only disappears during shutdown; discard
2584                    // the completed connection quietly rather than logging.
2585                    let _ = tx.send((wrapped, addr)).await;
2586                }
2587                Ok(Err(e)) => {
2588                    tracing::debug!("TLS handshake failed from {addr}: {e}");
2589                }
2590                Err(_elapsed) => {
2591                    tracing::debug!(
2592                        "TLS handshake timed out from {addr} after {handshake_timeout:?}"
2593                    );
2594                }
2595            }
2596        });
2597    }
2598}
2599
/// A TLS stream paired with the mTLS identity extracted at handshake time.
///
/// Wraps [`tokio_rustls::server::TlsStream`] so the verified client
/// identity travels with the connection itself. This replaces the previous
/// shared `MtlsIdentities` map, eliminating the
/// `(SocketAddr) -> AuthIdentity` aliasing risk caused by ephemeral-port
/// reuse and removing the need for an LRU eviction policy.
///
/// The `AsyncRead`/`AsyncWrite` delegation projects the pin with
/// `Pin::new(&mut self.inner)`, which requires the wrapper to be `Unpin`.
/// The inner stream is `Unpin` because [`tokio::net::TcpStream`] is.
/// NOTE(review): this also relies on `AuthIdentity` being `Unpin`, which is
/// not visible here; the compiler enforces it.
pub(crate) struct AuthenticatedTlsStream {
    // The handshaken server-side TLS stream over TCP.
    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
    // `None` when the peer presented no usable client certificate.
    identity: Option<AuthIdentity>,
}

impl AuthenticatedTlsStream {
    /// Returns the verified mTLS client identity, if any.
    #[must_use]
    pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
        self.identity.as_ref()
    }
}
2623
2624impl std::fmt::Debug for AuthenticatedTlsStream {
2625    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2626        f.debug_struct("AuthenticatedTlsStream")
2627            .field("identity", &self.identity.as_ref().map(|id| &id.name))
2628            .finish_non_exhaustive()
2629    }
2630}
2631
2632impl tokio::io::AsyncRead for AuthenticatedTlsStream {
2633    fn poll_read(
2634        mut self: Pin<&mut Self>,
2635        cx: &mut std::task::Context<'_>,
2636        buf: &mut tokio::io::ReadBuf<'_>,
2637    ) -> std::task::Poll<std::io::Result<()>> {
2638        Pin::new(&mut self.inner).poll_read(cx, buf)
2639    }
2640}
2641
2642impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
2643    fn poll_write(
2644        mut self: Pin<&mut Self>,
2645        cx: &mut std::task::Context<'_>,
2646        buf: &[u8],
2647    ) -> std::task::Poll<std::io::Result<usize>> {
2648        Pin::new(&mut self.inner).poll_write(cx, buf)
2649    }
2650
2651    fn poll_flush(
2652        mut self: Pin<&mut Self>,
2653        cx: &mut std::task::Context<'_>,
2654    ) -> std::task::Poll<std::io::Result<()>> {
2655        Pin::new(&mut self.inner).poll_flush(cx)
2656    }
2657
2658    fn poll_shutdown(
2659        mut self: Pin<&mut Self>,
2660        cx: &mut std::task::Context<'_>,
2661    ) -> std::task::Poll<std::io::Result<()>> {
2662        Pin::new(&mut self.inner).poll_shutdown(cx)
2663    }
2664
2665    fn poll_write_vectored(
2666        mut self: Pin<&mut Self>,
2667        cx: &mut std::task::Context<'_>,
2668        bufs: &[std::io::IoSlice<'_>],
2669    ) -> std::task::Poll<std::io::Result<usize>> {
2670        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
2671    }
2672
2673    fn is_write_vectored(&self) -> bool {
2674        self.inner.is_write_vectored()
2675    }
2676}
2677
2678impl axum::serve::Listener for TlsListener {
2679    type Io = AuthenticatedTlsStream;
2680    type Addr = SocketAddr;
2681
2682    /// Yield the next fully-handshaken TLS connection.
2683    ///
2684    /// Cancel-safe: this is a plain `mpsc::Receiver::recv`, so cancelling
2685    /// the future (axum selects it against graceful shutdown) never loses
2686    /// a connection.
2687    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
2688        if let Some(pair) = self.rx.recv().await {
2689            return pair;
2690        }
2691        // The channel only closes if the acceptor task terminated, which
2692        // means the TcpListener is gone and the OS already refuses new
2693        // connections. `Listener::accept` is infallible and panicking is
2694        // forbidden, so park forever: existing connections keep being
2695        // served and graceful shutdown still completes.
2696        tracing::error!("TLS acceptor task terminated; no further connections will be accepted");
2697        std::future::pending().await
2698    }
2699
2700    fn local_addr(&self) -> std::io::Result<Self::Addr> {
2701        Ok(self.local_addr)
2702    }
2703}
2704
/// Dropping the listener aborts the background acceptor task.
impl Drop for TlsListener {
    fn drop(&mut self) {
        // Stop accepting immediately and release the bound port (the
        // `TcpListener` is owned by the aborted task). In-flight handshake
        // workers are separate tasks and are not aborted here; their final
        // `send` fails once the receiver is gone and they exit quietly.
        self.acceptor_task.abort();
    }
}
2712
2713fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
2714    use rustls::pki_types::pem::PemObject;
2715    let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
2716        .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
2717        .collect::<Result<_, _>>()
2718        .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
2719    anyhow::ensure!(
2720        !certs.is_empty(),
2721        "no certificates found in {}",
2722        path.display()
2723    );
2724    Ok(certs)
2725}
2726
2727fn load_client_auth_roots(
2728    path: &Path,
2729) -> anyhow::Result<(
2730    Vec<rustls::pki_types::CertificateDer<'static>>,
2731    Arc<RootCertStore>,
2732)> {
2733    let ca_certs = load_certs(path)?;
2734    let mut root_store = RootCertStore::empty();
2735    for cert in &ca_certs {
2736        root_store
2737            .add(cert.clone())
2738            .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
2739    }
2740
2741    Ok((ca_certs, Arc::new(root_store)))
2742}
2743
2744fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
2745    use rustls::pki_types::pem::PemObject;
2746    rustls::pki_types::PrivateKeyDer::from_pem_file(path)
2747        .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
2748}
2749
2750#[allow(
2751    clippy::unused_async,
2752    reason = "axum route handler signature requires `async fn` even when the body is synchronous"
2753)]
2754async fn healthz() -> impl IntoResponse {
2755    axum::Json(serde_json::json!({
2756        "status": "ok",
2757    }))
2758}
2759
2760/// Build the `/version` JSON payload for a given server name and version.
2761///
2762/// Build metadata (`build_git_sha`, `build_timestamp`, `rust_version`) is
2763/// read at compile time from the `RMCP_SERVER_KIT_BUILD_SHA`,
2764/// `RMCP_SERVER_KIT_BUILD_TIME`, and `RMCP_SERVER_KIT_RUSTC_VERSION` env
2765/// vars. Unset values resolve to `"unknown"`.
2766fn version_payload(name: &str, version: &str) -> serde_json::Value {
2767    serde_json::json!({
2768        "name": name,
2769        "version": version,
2770        "build_git_sha": option_env!("RMCP_SERVER_KIT_BUILD_SHA").unwrap_or("unknown"),
2771        "build_timestamp": option_env!("RMCP_SERVER_KIT_BUILD_TIME").unwrap_or("unknown"),
2772        "rust_version": option_env!("RMCP_SERVER_KIT_RUSTC_VERSION").unwrap_or("unknown"),
2773        "mcpx_version": env!("CARGO_PKG_VERSION"),
2774    })
2775}
2776
2777/// Pre-serialize the `/version` payload to immutable bytes.
2778///
2779/// This is called once at router-build time so per-request handling can
2780/// reuse a cheap `Arc<[u8]>` clone instead of re-serializing a
2781/// [`serde_json::Value`] on every hit.
2782///
2783/// Serialization of a flat `serde_json::Value` of static-string fields
2784/// cannot fail in practice; the fallback to `b"{}"` exists only to
2785/// satisfy the crate-wide `unwrap_used` / `expect_used` lint policy.
2786fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
2787    let value = version_payload(name, version);
2788    serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
2789}
2790
2791async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2792    let status = check().await;
2793    let ready = status
2794        .get("ready")
2795        .and_then(serde_json::Value::as_bool)
2796        .unwrap_or(false);
2797    let code = if ready {
2798        axum::http::StatusCode::OK
2799    } else {
2800        axum::http::StatusCode::SERVICE_UNAVAILABLE
2801    };
2802    (code, axum::Json(status))
2803}
2804
2805/// Wait for SIGINT (ctrl-c) or SIGTERM (container stop).
2806///
2807/// On non-Unix platforms, only SIGINT is handled.
2808async fn shutdown_signal() {
2809    let ctrl_c = tokio::signal::ctrl_c();
2810
2811    #[cfg(unix)]
2812    {
2813        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
2814            Ok(mut term) => {
2815                tokio::select! {
2816                    _ = ctrl_c => {}
2817                    _ = term.recv() => {}
2818                }
2819            }
2820            Err(e) => {
2821                tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
2822                ctrl_c.await.ok();
2823            }
2824        }
2825    }
2826
2827    #[cfg(not(unix))]
2828    {
2829        ctrl_c.await.ok();
2830    }
2831}
2832
2833// -- Origin validation (MCP 2025-11-25 spec, section 2.0.1) --
2834
/// Middleware that records HTTP request metrics for every request passing
/// through it.
///
/// Increments `http_requests_total` labelled `(method, path, status)` and
/// observes the handling time in seconds on `http_request_duration_seconds`
/// labelled `(method, path)`.
///
/// NOTE(review): `path` is the raw request path, so every distinct URL
/// (including 404 probes) creates a new label set. Consider collapsing
/// to the matched route template if label cardinality becomes a concern.
#[cfg(feature = "metrics")]
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    // Capture the labels before `req` is moved into the inner service.
    let method = req.method().as_str().to_owned();
    let path = String::from(req.uri().path());
    let started = std::time::Instant::now();

    let response = next.run(req).await;

    let status = response.status().as_u16().to_string();
    let seconds = started.elapsed().as_secs_f64();

    metrics
        .http_requests_total
        .with_label_values(&[method.as_str(), path.as_str(), status.as_str()])
        .inc();
    metrics
        .http_request_duration_seconds
        .with_label_values(&[method.as_str(), path.as_str()])
        .observe(seconds);

    response
}
2864
2865/// OWASP security header hardening applied to every response.
2866///
2867/// Sets: `X-Content-Type-Options`, `X-Frame-Options`, `Cache-Control`,
2868/// `Referrer-Policy`, `Cross-Origin-Opener-Policy`, `Cross-Origin-Resource-Policy`,
2869/// `Cross-Origin-Embedder-Policy`, `Permissions-Policy`,
2870/// `X-Permitted-Cross-Domain-Policies`, `Content-Security-Policy`,
2871/// `X-DNS-Prefetch-Control`, and (when TLS is active) `Strict-Transport-Security`.
2872///
2873/// Each header's value can be customised via [`SecurityHeadersConfig`]
2874/// on [`McpServerConfig`]. See that type for the three-state semantic
2875/// (`None` = default, `Some("")` = omit, `Some(v)` = override).
2876async fn security_headers_middleware(
2877    is_tls: bool,
2878    cfg: Arc<SecurityHeadersConfig>,
2879    req: Request<Body>,
2880    next: Next,
2881) -> axum::response::Response {
2882    use axum::http::{HeaderName, header};
2883
2884    let mut resp = next.run(req).await;
2885    let headers = resp.headers_mut();
2886
2887    // Strip server identity headers to reduce information leakage.
2888    headers.remove(header::SERVER);
2889    headers.remove(HeaderName::from_static("x-powered-by"));
2890
2891    apply_security_header(
2892        headers,
2893        header::X_CONTENT_TYPE_OPTIONS,
2894        cfg.x_content_type_options.as_deref(),
2895        "nosniff",
2896    );
2897    apply_security_header(
2898        headers,
2899        header::X_FRAME_OPTIONS,
2900        cfg.x_frame_options.as_deref(),
2901        "deny",
2902    );
2903    apply_security_header(
2904        headers,
2905        header::CACHE_CONTROL,
2906        cfg.cache_control.as_deref(),
2907        "no-store, max-age=0",
2908    );
2909    apply_security_header(
2910        headers,
2911        header::REFERRER_POLICY,
2912        cfg.referrer_policy.as_deref(),
2913        "no-referrer",
2914    );
2915    apply_security_header(
2916        headers,
2917        HeaderName::from_static("cross-origin-opener-policy"),
2918        cfg.cross_origin_opener_policy.as_deref(),
2919        "same-origin",
2920    );
2921    apply_security_header(
2922        headers,
2923        HeaderName::from_static("cross-origin-resource-policy"),
2924        cfg.cross_origin_resource_policy.as_deref(),
2925        "same-origin",
2926    );
2927    apply_security_header(
2928        headers,
2929        HeaderName::from_static("cross-origin-embedder-policy"),
2930        cfg.cross_origin_embedder_policy.as_deref(),
2931        "require-corp",
2932    );
2933    apply_security_header(
2934        headers,
2935        HeaderName::from_static("permissions-policy"),
2936        cfg.permissions_policy.as_deref(),
2937        "accelerometer=(), camera=(), geolocation=(), microphone=()",
2938    );
2939    apply_security_header(
2940        headers,
2941        HeaderName::from_static("x-permitted-cross-domain-policies"),
2942        cfg.x_permitted_cross_domain_policies.as_deref(),
2943        "none",
2944    );
2945    apply_security_header(
2946        headers,
2947        HeaderName::from_static("content-security-policy"),
2948        cfg.content_security_policy.as_deref(),
2949        "default-src 'none'; frame-ancestors 'none'",
2950    );
2951    apply_security_header(
2952        headers,
2953        HeaderName::from_static("x-dns-prefetch-control"),
2954        cfg.x_dns_prefetch_control.as_deref(),
2955        "off",
2956    );
2957
2958    if is_tls {
2959        apply_security_header(
2960            headers,
2961            header::STRICT_TRANSPORT_SECURITY,
2962            cfg.strict_transport_security.as_deref(),
2963            "max-age=63072000; includeSubDomains",
2964        );
2965    }
2966
2967    resp
2968}
2969
2970/// Set a single security header on the response, honouring the
2971/// three-state override semantic (None = default, Some("") = omit,
2972/// Some(value) = override).
2973///
2974/// Defence-in-depth: if an override value somehow reaches this point
2975/// despite [`validate_security_headers`] having approved it (e.g. a
2976/// runtime mutation on a non-`Validated` field), we log at error level
2977/// and fall back to the static default rather than panicking. The
2978/// `Validated<McpServerConfig>` type makes that path unreachable in
2979/// well-typed code paths.
2980fn apply_security_header(
2981    headers: &mut axum::http::HeaderMap,
2982    name: axum::http::HeaderName,
2983    override_value: Option<&str>,
2984    default: &'static str,
2985) {
2986    use axum::http::HeaderValue;
2987
2988    match override_value {
2989        None => {
2990            headers.insert(name, HeaderValue::from_static(default));
2991        }
2992        Some("") => {
2993            // Operator explicitly opted out of this header.
2994        }
2995        Some(v) => match HeaderValue::from_str(v) {
2996            Ok(hv) => {
2997                headers.insert(name, hv);
2998            }
2999            Err(err) => {
3000                tracing::error!(
3001                    header = %name,
3002                    error = %err,
3003                    "invalid security header override reached middleware; using default"
3004                );
3005                headers.insert(name, HeaderValue::from_static(default));
3006            }
3007        },
3008    }
3009}
3010
3011/// Validate every non-empty entry in a [`SecurityHeadersConfig`].
3012///
3013/// - `None` and `Some("")` are accepted unconditionally (use-default and
3014///   omit, respectively).
3015/// - `Some(v)` is rejected if `axum::http::HeaderValue::from_str(v)` fails.
3016/// - `strict_transport_security` additionally rejects any value
3017///   containing `preload` (case-insensitive). Operators who genuinely
3018///   want to commit to the HSTS preload list must do so via a future
3019///   explicit `with_hsts_preload(true)` builder, not by smuggling
3020///   `preload` through this knob.
3021fn validate_security_headers(cfg: &SecurityHeadersConfig) -> Result<(), McpxError> {
3022    use axum::http::HeaderValue;
3023
3024    let fields: &[(&str, Option<&str>)] = &[
3025        (
3026            "x_content_type_options",
3027            cfg.x_content_type_options.as_deref(),
3028        ),
3029        ("x_frame_options", cfg.x_frame_options.as_deref()),
3030        ("cache_control", cfg.cache_control.as_deref()),
3031        ("referrer_policy", cfg.referrer_policy.as_deref()),
3032        (
3033            "cross_origin_opener_policy",
3034            cfg.cross_origin_opener_policy.as_deref(),
3035        ),
3036        (
3037            "cross_origin_resource_policy",
3038            cfg.cross_origin_resource_policy.as_deref(),
3039        ),
3040        (
3041            "cross_origin_embedder_policy",
3042            cfg.cross_origin_embedder_policy.as_deref(),
3043        ),
3044        ("permissions_policy", cfg.permissions_policy.as_deref()),
3045        (
3046            "x_permitted_cross_domain_policies",
3047            cfg.x_permitted_cross_domain_policies.as_deref(),
3048        ),
3049        (
3050            "content_security_policy",
3051            cfg.content_security_policy.as_deref(),
3052        ),
3053        (
3054            "x_dns_prefetch_control",
3055            cfg.x_dns_prefetch_control.as_deref(),
3056        ),
3057        (
3058            "strict_transport_security",
3059            cfg.strict_transport_security.as_deref(),
3060        ),
3061    ];
3062
3063    for (field, value) in fields {
3064        let Some(v) = value else { continue };
3065        if v.is_empty() {
3066            continue;
3067        }
3068        if let Err(err) = HeaderValue::from_str(v) {
3069            return Err(McpxError::Config(format!(
3070                "invalid security_headers.{field}: {err}"
3071            )));
3072        }
3073    }
3074
3075    if let Some(v) = cfg.strict_transport_security.as_deref()
3076        && !v.is_empty()
3077        && v.to_ascii_lowercase().contains("preload")
3078    {
3079        return Err(McpxError::Config(format!(
3080            "invalid security_headers.strict_transport_security: {v:?} contains the `preload` directive; \
3081             HSTS preload must be opted into explicitly via a dedicated builder, not via this knob"
3082        )));
3083    }
3084
3085    Ok(())
3086}
3087
/// Append RFC 6749 §5.1 / RFC 6750 §5.4 cache and `Vary` headers required
/// on OAuth token-issuing responses.
///
/// `Cache-Control: no-store, max-age=0` is normally applied globally by
/// [`security_headers_middleware`]. That layer lets an operator override or
/// omit `Cache-Control`, though. When the response reaching this middleware
/// has no `Cache-Control` at all, `Cache-Control: no-store` is inserted here
/// so that token-class responses still carry the directive RFC 6749 §5.1
/// requires. An existing `Cache-Control` value is never replaced.
///
/// This middleware also adds:
///
/// - `Pragma: no-cache` -- mandated by RFC 6749 §5.1 for HTTP/1.0 caches.
/// - `Vary: Authorization` -- mandated by RFC 6750 §5.4 for endpoints
///   whose response depends on the `Authorization` header.
///
/// Applied only to the OAuth proxy token-class endpoints (`/token`,
/// `/register`, `/introspect`, `/revoke`). `Vary` is appended (not
/// inserted) so any `Vary` value already present (e.g. `Accept-Encoding`
/// from a compression layer, or `Origin` from a CORS layer) is preserved.
#[cfg(feature = "oauth")]
async fn oauth_token_cache_headers_middleware(
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    use axum::http::{HeaderValue, header};

    let mut resp = next.run(req).await;
    let headers = resp.headers_mut();
    // Fill the gap left when the global security-headers layer has been
    // configured to omit Cache-Control; never override an existing value.
    if !headers.contains_key(header::CACHE_CONTROL) {
        headers.insert(header::CACHE_CONTROL, HeaderValue::from_static("no-store"));
    }
    headers.insert(header::PRAGMA, HeaderValue::from_static("no-cache"));
    headers.append(header::VARY, HeaderValue::from_static("Authorization"));
    resp
}
3115
3116/// Normalize peer-address request extensions across listener branches.
3117///
3118/// The make-service installs `ConnectInfo<SocketAddr>` on the plain
3119/// listener but `ConnectInfo<TlsConnInfo>` on the TLS listener (the
3120/// latter additionally carries the connection-bound mTLS identity and
3121/// stays `pub(crate)` — see the anti-aliasing rationale on
3122/// [`TlsConnInfo`]). Application routes — in particular those merged via
3123/// [`McpServerConfig::with_extra_router`], which bypass the auth
3124/// middleware and its private fallback — could therefore not read the
3125/// peer address under TLS.
3126///
3127/// This middleware makes both branches look identical to every route and
3128/// inner middleware:
3129///
3130/// 1. mirrors the TLS peer address into `ConnectInfo<SocketAddr>` when
3131///    (and only when) it is absent, so stock axum-ecosystem extractors
3132///    work unmodified,
3133/// 2. inserts the framework-owned [`PeerAddr`] extension on both
3134///    branches, and
3135/// 3. inserts the resolved [`ClientIp`] extension: the direct peer's IP,
3136///    unless trusted-forwarder mode is configured AND the direct peer is
3137///    a trusted proxy AND the forwarding chain resolves — every
3138///    ambiguous chain falls back to the direct peer with only a reason
3139///    code logged at `debug` (never raw header contents).
3140///
3141/// Precedence mirrors the auth middleware: an existing
3142/// `ConnectInfo<SocketAddr>` always wins and is never overwritten. The
3143/// peer address is deliberately not logged here.
3144async fn normalize_peer_addr_middleware(
3145    resolver: Option<Arc<ForwardResolver>>,
3146    mut req: Request<Body>,
3147    next: Next,
3148) -> axum::response::Response {
3149    let direct = req
3150        .extensions()
3151        .get::<ConnectInfo<SocketAddr>>()
3152        .map(|ci| ci.0);
3153    let from_tls = req
3154        .extensions()
3155        .get::<ConnectInfo<TlsConnInfo>>()
3156        .map(|ci| ci.0.addr);
3157    if let Some(addr) = direct.or(from_tls) {
3158        if direct.is_none() {
3159            req.extensions_mut().insert(ConnectInfo(addr));
3160        }
3161        req.extensions_mut().insert(PeerAddr::new(addr));
3162        let client_ip = match &resolver {
3163            Some(r) => {
3164                crate::forwarded::resolve_client_ip(addr.ip(), req.headers(), &r.trusted, r.mode)
3165                    .unwrap_or_else(|reason| {
3166                        tracing::debug!(
3167                            reason = ?reason,
3168                            "forwarded-header resolution fell back to direct peer"
3169                        );
3170                        addr.ip()
3171                    })
3172            }
3173            None => addr.ip(),
3174        };
3175        req.extensions_mut().insert(ClientIp::new(client_ip));
3176    }
3177    next.run(req).await
3178}
3179
3180/// Parse a trusted-proxy entry: a CIDR (`10.0.0.0/8`) or a bare IP
3181/// (normalized to a `/32` / `/128` host network).
3182fn parse_proxy_net(entry: &str) -> Option<ipnet::IpNet> {
3183    if let Ok(net) = entry.parse::<ipnet::IpNet>() {
3184        return Some(net);
3185    }
3186    entry.parse::<IpAddr>().ok().map(ipnet::IpNet::from)
3187}
3188
3189/// Rate-limit key for the current request: the resolved [`ClientIp`]
3190/// when present, else the direct peer from either `ConnectInfo` form.
3191/// All four built-in limiters key through this helper.
3192pub(crate) fn limiter_client_ip(extensions: &axum::http::Extensions) -> Option<IpAddr> {
3193    if let Some(client) = extensions.get::<ClientIp>() {
3194        return Some(client.ip);
3195    }
3196    extensions
3197        .get::<ConnectInfo<SocketAddr>>()
3198        .map(|ci| ci.0.ip())
3199        .or_else(|| {
3200            extensions
3201                .get::<ConnectInfo<TlsConnInfo>>()
3202                .map(|ci| ci.0.addr.ip())
3203        })
3204}
3205
3206/// Per-IP rate limiter for `extra_router` routes, keyed by the direct
3207/// socket peer address. Same memory-bounded machinery as the tool
3208/// limiter ([`crate::rbac`]).
3209pub(crate) type ExtraRouteRateLimiter = BoundedKeyedLimiter<IpAddr>;
3210
/// Maximum number of distinct source IPs the extra-route limiter keeps
/// state for, matching the tool limiter's bound.
///
/// Once saturated, memory stays bounded through idle pruning plus LRU
/// eviction. The trade-off is shared-fate fairness under key spray: an
/// attacker cycling through many IPs can push quieter legitimate IPs back
/// to fresh buckets.
const EXTRA_ROUTE_MAX_TRACKED_KEYS: usize = 10 * 1_000;

/// How long an extra-route limiter entry may sit idle before eviction:
/// fifteen minutes, the same window the tool limiter uses.
const EXTRA_ROUTE_IDLE_EVICTION: Duration = Duration::from_secs(15 * 60);
3221
3222/// Build the per-IP limiter for `extra_router` routes.
3223///
3224/// `per_minute` and `burst` are validated nonzero by
3225/// [`McpServerConfig::validate`]; the `NonZeroU32` fallbacks here are
3226/// defensive only. `burst` overrides governor's default bucket capacity
3227/// (burst = rate).
3228fn build_extra_route_rate_limiter(
3229    per_minute: u32,
3230    burst: Option<u32>,
3231) -> Arc<ExtraRouteRateLimiter> {
3232    let rate = std::num::NonZeroU32::new(per_minute.max(1)).unwrap_or(std::num::NonZeroU32::MIN);
3233    let mut quota = governor::Quota::per_minute(rate);
3234    if let Some(b) = burst.and_then(std::num::NonZeroU32::new) {
3235        quota = quota.allow_burst(b);
3236    }
3237    Arc::new(BoundedKeyedLimiter::new(
3238        quota,
3239        EXTRA_ROUTE_MAX_TRACKED_KEYS,
3240        EXTRA_ROUTE_IDLE_EVICTION,
3241    ))
3242}
3243
3244/// Per-IP rate limit middleware for `extra_router` routes.
3245///
3246/// Applied to the application-supplied router **before** it is merged
3247/// into the top-level router, so it wraps exactly the extra routes
3248/// (and their fallback, if any) and nothing else — `/mcp`, health,
3249/// admin, and OAuth endpoints are never affected. Outer layers (origin
3250/// check, peer-address normalization, security headers, metrics) still
3251/// wrap these routes and run first, so both `ConnectInfo` forms are
3252/// populated by the time this middleware reads them.
3253///
3254/// Semantics mirror the tool/auth limiters exactly: keyed by the
3255/// direct peer `IpAddr` (no `X-Forwarded-For`), fail-open when no peer
3256/// address is present (cannot happen under [`serve`]), and on limit a
3257/// plain-text 429 via [`McpxError::RateLimitedFor`] carrying a
3258/// `Retry-After` header (delta-seconds), consistent with every other
3259/// limiter in the crate.
3260async fn extra_route_rate_limit_middleware(
3261    limiter: Arc<ExtraRouteRateLimiter>,
3262    req: Request<Body>,
3263    next: Next,
3264) -> axum::response::Response {
3265    let peer_ip: Option<IpAddr> = limiter_client_ip(req.extensions());
3266    if let Some(ip) = peer_ip
3267        && let Err(wait) = limiter.check_key_wait(&ip)
3268    {
3269        tracing::warn!(%ip, "extra route request rate limited");
3270        return McpxError::RateLimitedFor {
3271            message: "too many requests to application routes from this source".into(),
3272            retry_after: wait,
3273        }
3274        .into_response();
3275    }
3276    next.run(req).await
3277}
3278
3279/// Per the MCP spec: if the Origin header is present and its value is not in
3280/// the allowed list, respond with 403 Forbidden. Requests without an Origin
3281/// header are allowed through (e.g. non-browser clients like curl, SDKs).
3282async fn origin_check_middleware(
3283    allowed: Arc<[String]>,
3284    log_request_headers: bool,
3285    req: Request<Body>,
3286    next: Next,
3287) -> axum::response::Response {
3288    let method = req.method().clone();
3289    let path = req.uri().path().to_owned();
3290
3291    log_incoming_request(&method, &path, req.headers(), log_request_headers);
3292
3293    if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
3294        let origin_str = origin.to_str().unwrap_or("");
3295        if !allowed.iter().any(|a| a == origin_str) {
3296            tracing::warn!(
3297                origin = origin_str,
3298                %method,
3299                %path,
3300                allowed = ?&*allowed,
3301                "rejected request: Origin not allowed"
3302            );
3303            return (
3304                axum::http::StatusCode::FORBIDDEN,
3305                "Forbidden: Origin not allowed",
3306            )
3307                .into_response();
3308        }
3309    }
3310    next.run(req).await
3311}
3312
3313/// Emit a DEBUG log for an incoming request, optionally including the full
3314/// (redacted) header set.
3315fn log_incoming_request(
3316    method: &axum::http::Method,
3317    path: &str,
3318    headers: &axum::http::HeaderMap,
3319    log_request_headers: bool,
3320) {
3321    if log_request_headers {
3322        tracing::debug!(
3323            %method,
3324            %path,
3325            headers = %format_request_headers_for_log(headers),
3326            "incoming request"
3327        );
3328    } else {
3329        tracing::debug!(%method, %path, "incoming request");
3330    }
3331}
3332
3333fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
3334    headers
3335        .iter()
3336        .map(|(k, v)| {
3337            let name = k.as_str();
3338            if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
3339                format!("{name}: [REDACTED]")
3340            } else {
3341                format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
3342            }
3343        })
3344        .collect::<Vec<_>>()
3345        .join(", ")
3346}
3347
3348// -- stdio transport --
3349
3350/// Serve an MCP server over stdin/stdout (stdio transport).
3351///
3352/// # Security warnings
3353///
3354/// - **No authentication**: the parent process has full, unrestricted access.
3355/// - **No RBAC**: all tools are available regardless of policy.
3356/// - **No TLS**: messages travel over OS pipes in plaintext.
3357/// - **Single client**: only the parent process can connect.
3358/// - **No Origin validation**: not applicable to stdio.
3359///
3360/// Use this only when the MCP client spawns the server as a trusted subprocess
3361/// (e.g. Claude Desktop, VS Code Copilot). For network-accessible deployments,
3362/// use `serve()` (Streamable HTTP) instead.
3363///
3364/// # Errors
3365///
3366/// Returns [`McpxError::Startup`] if the handler fails to initialize or the
3367/// transport disconnects unexpectedly.
3368// NOTE: reported complexity 32/25 is driven entirely by `tracing::*!`
3369// macro expansion in this 18-line function (info/warn/info + two matches).
3370// There is nothing meaningful to extract; the allow stays.
3371#[allow(
3372    clippy::cognitive_complexity,
3373    reason = "complexity is purely tracing macro expansion (info/warn + match arms); 18 lines of straight-line code, nothing meaningful to extract"
3374)]
3375pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
3376where
3377    H: ServerHandler + 'static,
3378{
3379    use rmcp::ServiceExt as _;
3380
3381    tracing::info!("stdio transport: serving on stdin/stdout");
3382    tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
3383
3384    let transport = rmcp::transport::io::stdio();
3385
3386    let service = handler
3387        .serve(transport)
3388        .await
3389        .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
3390
3391    if let Err(e) = service.waiting().await {
3392        tracing::warn!(error = %e, "stdio session ended with error");
3393    }
3394    tracing::info!("stdio session ended");
3395    Ok(())
3396}
3397
3398#[cfg(test)]
3399mod tests {
3400    #![allow(
3401        clippy::unwrap_used,
3402        clippy::expect_used,
3403        clippy::panic,
3404        clippy::indexing_slicing,
3405        clippy::unwrap_in_result,
3406        clippy::print_stdout,
3407        clippy::print_stderr,
3408        deprecated,
3409        reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
3410    )]
3411    use std::{sync::Arc, time::Duration};
3412
3413    use axum::{
3414        body::Body,
3415        http::{Request, StatusCode, header},
3416        response::IntoResponse,
3417    };
3418    use http_body_util::BodyExt;
3419    use tower::ServiceExt as _;
3420
3421    use super::*;
3422
3423    // -- McpServerConfig --
3424
3425    #[test]
3426    fn server_config_new_defaults() {
3427        let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
3428        assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
3429        assert_eq!(cfg.name, "test-server");
3430        assert_eq!(cfg.version, "1.0.0");
3431        assert!(cfg.tls_cert_path.is_none());
3432        assert!(cfg.tls_key_path.is_none());
3433        assert!(cfg.auth.is_none());
3434        assert!(cfg.rbac.is_none());
3435        assert!(cfg.allowed_origins.is_empty());
3436        assert!(cfg.tool_rate_limit.is_none());
3437        assert!(cfg.readiness_check.is_none());
3438        assert_eq!(cfg.max_request_body, 1024 * 1024);
3439        assert_eq!(cfg.request_timeout, Duration::from_mins(2));
3440        assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
3441        assert!(!cfg.log_request_headers);
3442        assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(10));
3443        assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
3444    }
3445
3446    #[test]
3447    fn tls_handshake_builders_set_fields() {
3448        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3449            .with_tls_handshake_timeout(Duration::from_secs(3))
3450            .with_max_concurrent_tls_handshakes(64);
3451        assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(3));
3452        assert_eq!(cfg.max_concurrent_tls_handshakes, 64);
3453    }
3454
3455    #[test]
3456    fn validate_rejects_zero_tls_handshake_timeout() {
3457        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3458            .with_tls_handshake_timeout(Duration::ZERO);
3459        let err = cfg.validate().expect_err("zero handshake timeout");
3460        assert!(err.to_string().contains("tls_handshake_timeout"));
3461    }
3462
3463    #[test]
3464    fn validate_rejects_zero_max_concurrent_tls_handshakes() {
3465        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3466            .with_max_concurrent_tls_handshakes(0);
3467        let err = cfg.validate().expect_err("zero handshake concurrency");
3468        assert!(err.to_string().contains("max_concurrent_tls_handshakes"));
3469    }
3470
3471    #[test]
3472    fn validate_consumes_and_proves() {
3473        // Valid config -> Validated wrapper, original is consumed.
3474        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
3475        let validated = cfg.validate().expect("valid config");
3476        // as_inner() gives read-only access to inner fields.
3477        assert_eq!(validated.as_inner().name, "test-server");
3478        // into_inner recovers the raw value.
3479        let raw = validated.into_inner();
3480        assert_eq!(raw.name, "test-server");
3481
3482        // Invalid config (zero max_request_body) -> Err.
3483        let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
3484        bad.max_request_body = 0;
3485        assert!(bad.validate().is_err(), "zero body cap must fail validate");
3486    }
3487
3488    #[test]
3489    fn validate_rejects_zero_max_concurrent_requests() {
3490        let cfg =
3491            McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_max_concurrent_requests(0);
3492        let err = cfg.validate().expect_err("zero concurrency cap must fail");
3493        assert!(
3494            format!("{err}").contains("max_concurrent_requests"),
3495            "error should mention max_concurrent_requests, got: {err}"
3496        );
3497    }
3498
3499    #[test]
3500    fn validate_rejects_zero_max_tracked_keys() {
3501        // Defaults mirror auth::default_max_attempts / default_idle_eviction
3502        // (module-private in auth.rs); spelled out here for review clarity.
3503        let rl = crate::auth::RateLimitConfig {
3504            max_attempts_per_minute: 30,
3505            pre_auth_max_per_minute: None,
3506            max_tracked_keys: 0,
3507            idle_eviction: Duration::from_secs(15 * 60),
3508            burst: None,
3509            pre_auth_burst: None,
3510        };
3511        let auth_cfg = AuthConfig {
3512            enabled: true,
3513            api_keys: Vec::new(),
3514            mtls: None,
3515            rate_limit: Some(rl),
3516            #[cfg(feature = "oauth")]
3517            oauth: None,
3518        };
3519        let cfg = McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_auth(auth_cfg);
3520        let err = cfg.validate().expect_err("zero max_tracked_keys must fail");
3521        assert!(
3522            format!("{err}").contains("max_tracked_keys"),
3523            "error should mention max_tracked_keys, got: {err}"
3524        );
3525    }
3526
3527    #[test]
3528    fn derive_allowed_hosts_includes_public_host() {
3529        let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
3530        assert!(
3531            hosts.iter().any(|h| h == "mcp.example.com"),
3532            "public_url host must be allowed"
3533        );
3534    }
3535
3536    #[test]
3537    fn derive_allowed_hosts_includes_bind_authority() {
3538        let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
3539        assert!(
3540            hosts.iter().any(|h| h == "127.0.0.1"),
3541            "bind host must be allowed"
3542        );
3543        assert!(
3544            hosts.iter().any(|h| h == "127.0.0.1:8080"),
3545            "bind authority must be allowed"
3546        );
3547    }
3548
3549    // -- healthz --
3550
3551    #[tokio::test]
3552    async fn healthz_returns_ok_json() {
3553        let resp = healthz().await.into_response();
3554        assert_eq!(resp.status(), StatusCode::OK);
3555        let body = resp.into_body().collect().await.unwrap().to_bytes();
3556        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
3557        assert_eq!(json["status"], "ok");
3558        assert!(
3559            json.get("name").is_none(),
3560            "healthz must not expose server name"
3561        );
3562        assert!(
3563            json.get("version").is_none(),
3564            "healthz must not expose version"
3565        );
3566    }
3567
3568    // -- readyz --
3569
3570    #[tokio::test]
3571    async fn readyz_returns_ok_when_ready() {
3572        let check: ReadinessCheck =
3573            Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
3574        let resp = readyz(check).await.into_response();
3575        assert_eq!(resp.status(), StatusCode::OK);
3576        let body = resp.into_body().collect().await.unwrap().to_bytes();
3577        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
3578        assert_eq!(json["ready"], true);
3579        assert!(
3580            json.get("name").is_none(),
3581            "readyz must not expose server name"
3582        );
3583        assert!(
3584            json.get("version").is_none(),
3585            "readyz must not expose version"
3586        );
3587        assert_eq!(json["db"], "connected");
3588    }
3589
3590    #[tokio::test]
3591    async fn readyz_returns_503_when_not_ready() {
3592        let check: ReadinessCheck =
3593            Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
3594        let resp = readyz(check).await.into_response();
3595        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
3596    }
3597
3598    #[tokio::test]
3599    async fn readyz_returns_503_when_ready_missing() {
3600        let check: ReadinessCheck =
3601            Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
3602        let resp = readyz(check).await.into_response();
3603        // Missing "ready" field defaults to false -> 503
3604        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
3605    }
3606
3607    // -- normalize_peer_addr_middleware / PeerAddr --
3608
3609    /// Build a test router that reports the request's peer-address
3610    /// extensions as `"<ConnectInfo>|<PeerAddr>"` (empty when absent).
3611    fn peer_probe_router() -> axum::Router {
3612        async fn probe(req: Request<Body>) -> String {
3613            let ci = req
3614                .extensions()
3615                .get::<ConnectInfo<SocketAddr>>()
3616                .map(|c| c.0.to_string())
3617                .unwrap_or_default();
3618            let pa = req
3619                .extensions()
3620                .get::<PeerAddr>()
3621                .map(|p| p.addr.to_string())
3622                .unwrap_or_default();
3623            format!("{ci}|{pa}")
3624        }
3625        axum::Router::new()
3626            .route("/probe", axum::routing::get(probe))
3627            .layer(axum::middleware::from_fn(|req, next| {
3628                normalize_peer_addr_middleware(None, req, next)
3629            }))
3630    }
3631
3632    async fn body_string(resp: axum::response::Response) -> String {
3633        let bytes = resp.into_body().collect().await.unwrap().to_bytes();
3634        String::from_utf8(bytes.to_vec()).unwrap()
3635    }
3636
3637    #[tokio::test]
3638    async fn normalize_preserves_existing_connect_info_and_mirrors_peer_addr() {
3639        // Precedence proof: when both extensions exist with DIFFERENT
3640        // addresses, ConnectInfo<SocketAddr> wins and is never overwritten.
3641        let plain: SocketAddr = "10.0.0.1:1111".parse().unwrap();
3642        let tls: SocketAddr = "10.0.0.2:2222".parse().unwrap();
3643        let req = Request::builder()
3644            .uri("/probe")
3645            .extension(ConnectInfo(plain))
3646            .extension(ConnectInfo(TlsConnInfo::new(tls, None)))
3647            .body(Body::empty())
3648            .unwrap();
3649        let resp = peer_probe_router().oneshot(req).await.unwrap();
3650        assert_eq!(resp.status(), StatusCode::OK);
3651        assert_eq!(body_string(resp).await, format!("{plain}|{plain}"));
3652    }
3653
3654    #[tokio::test]
3655    async fn normalize_inserts_connect_info_and_peer_addr_from_tls() {
3656        let tls: SocketAddr = "192.168.1.7:50443".parse().unwrap();
3657        let req = Request::builder()
3658            .uri("/probe")
3659            .extension(ConnectInfo(TlsConnInfo::new(tls, None)))
3660            .body(Body::empty())
3661            .unwrap();
3662        let resp = peer_probe_router().oneshot(req).await.unwrap();
3663        assert_eq!(resp.status(), StatusCode::OK);
3664        assert_eq!(body_string(resp).await, format!("{tls}|{tls}"));
3665    }
3666
3667    #[tokio::test]
3668    async fn normalize_no_op_without_any_connect_info() {
3669        let req = Request::builder()
3670            .uri("/probe")
3671            .body(Body::empty())
3672            .unwrap();
3673        let resp = peer_probe_router().oneshot(req).await.unwrap();
3674        assert_eq!(resp.status(), StatusCode::OK);
3675        assert_eq!(body_string(resp).await, "|");
3676    }
3677
3678    #[tokio::test]
3679    async fn peer_addr_extractor_rejects_when_absent() {
3680        async fn h(peer: PeerAddr) -> String {
3681            peer.addr.to_string()
3682        }
3683        let app = axum::Router::new().route("/p", axum::routing::get(h));
3684        let req = Request::builder().uri("/p").body(Body::empty()).unwrap();
3685        let resp = app.oneshot(req).await.unwrap();
3686        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
3687    }
3688
3689    #[tokio::test]
3690    async fn peer_addr_extractor_returns_value_when_present() {
3691        async fn h(peer: PeerAddr) -> String {
3692            peer.addr.to_string()
3693        }
3694        let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap();
3695        let app = axum::Router::new().route("/p", axum::routing::get(h));
3696        let req = Request::builder()
3697            .uri("/p")
3698            .extension(PeerAddr::new(addr))
3699            .body(Body::empty())
3700            .unwrap();
3701        let resp = app.oneshot(req).await.unwrap();
3702        assert_eq!(resp.status(), StatusCode::OK);
3703        assert_eq!(body_string(resp).await, addr.to_string());
3704    }
3705
3706    #[tokio::test]
3707    async fn peer_addr_via_extension_extractor() {
3708        async fn h(axum::Extension(peer): axum::Extension<PeerAddr>) -> String {
3709            peer.addr.to_string()
3710        }
3711        let addr: SocketAddr = "127.0.0.1:4242".parse().unwrap();
3712        let app = axum::Router::new().route("/p", axum::routing::get(h));
3713        let req = Request::builder()
3714            .uri("/p")
3715            .extension(PeerAddr::new(addr))
3716            .body(Body::empty())
3717            .unwrap();
3718        let resp = app.oneshot(req).await.unwrap();
3719        assert_eq!(resp.status(), StatusCode::OK);
3720        assert_eq!(body_string(resp).await, addr.to_string());
3721    }
3722
3723    // -- extra_route_rate_limit_middleware --
3724
3725    /// Probe router with the extra-route limiter installed, mirroring
3726    /// the layer-before-merge wiring in `build_app_router`.
3727    fn limited_router(per_minute: u32) -> axum::Router {
3728        limited_router_with_burst(per_minute, None)
3729    }
3730
3731    /// Probe router with an explicit burst capacity.
3732    fn limited_router_with_burst(per_minute: u32, burst: Option<u32>) -> axum::Router {
3733        let limiter = build_extra_route_rate_limiter(per_minute, burst);
3734        axum::Router::new()
3735            .route("/limited", axum::routing::get(|| async { "ok" }))
3736            .layer(axum::middleware::from_fn(move |req, next| {
3737                let l = Arc::clone(&limiter);
3738                extra_route_rate_limit_middleware(l, req, next)
3739            }))
3740    }
3741
3742    fn limited_req(ip: &str) -> Request<Body> {
3743        let addr: SocketAddr = format!("{ip}:40000").parse().unwrap();
3744        Request::builder()
3745            .uri("/limited")
3746            .extension(ConnectInfo(addr))
3747            .body(Body::empty())
3748            .unwrap()
3749    }
3750
3751    #[tokio::test]
3752    async fn extra_route_limiter_denies_over_quota() {
3753        let app = limited_router(2);
3754        for i in 0..2 {
3755            let resp = app.clone().oneshot(limited_req("10.1.1.1")).await.unwrap();
3756            assert_eq!(resp.status(), StatusCode::OK, "request {i} should pass");
3757        }
3758        let resp = app.clone().oneshot(limited_req("10.1.1.1")).await.unwrap();
3759        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
3760        let body = body_string(resp).await;
3761        assert!(
3762            body.contains("too many requests to application routes"),
3763            "deny body should match the limiter message, got: {body}"
3764        );
3765    }
3766
3767    #[tokio::test]
3768    async fn extra_route_limiter_isolates_keys() {
3769        let app = limited_router(2);
3770        for _ in 0..2 {
3771            let resp = app.clone().oneshot(limited_req("10.2.2.2")).await.unwrap();
3772            assert_eq!(resp.status(), StatusCode::OK);
3773        }
3774        let exhausted = app.clone().oneshot(limited_req("10.2.2.2")).await.unwrap();
3775        assert_eq!(exhausted.status(), StatusCode::TOO_MANY_REQUESTS);
3776        // A different source IP still has a fresh bucket.
3777        let other = app.clone().oneshot(limited_req("10.3.3.3")).await.unwrap();
3778        assert_eq!(other.status(), StatusCode::OK);
3779    }
3780
3781    #[tokio::test]
3782    async fn extra_route_limiter_fails_open_without_peer() {
3783        let app = limited_router(1);
3784        for i in 0..3 {
3785            let req = Request::builder()
3786                .uri("/limited")
3787                .body(Body::empty())
3788                .unwrap();
3789            let resp = app.clone().oneshot(req).await.unwrap();
3790            assert_eq!(
3791                resp.status(),
3792                StatusCode::OK,
3793                "request {i} should fail open"
3794            );
3795        }
3796    }
3797
3798    #[tokio::test]
3799    async fn extra_route_limiter_extracts_tls_conn_info() {
3800        let app = limited_router(2);
3801        let mk = || {
3802            let addr: SocketAddr = "192.168.9.9:55555".parse().unwrap();
3803            Request::builder()
3804                .uri("/limited")
3805                .extension(ConnectInfo(TlsConnInfo::new(addr, None)))
3806                .body(Body::empty())
3807                .unwrap()
3808        };
3809        for _ in 0..2 {
3810            assert_eq!(
3811                app.clone().oneshot(mk()).await.unwrap().status(),
3812                StatusCode::OK
3813            );
3814        }
3815        let resp = app.clone().oneshot(mk()).await.unwrap();
3816        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
3817    }
3818
3819    #[test]
3820    fn validate_rejects_zero_extra_route_rate_limit() {
3821        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3822            .with_extra_route_rate_limit(0);
3823        let err = cfg.validate().expect_err("zero extra route rate limit");
3824        assert!(err.to_string().contains("extra_route_rate_limit"));
3825    }
3826
3827    #[tokio::test]
3828    async fn extra_route_limiter_burst_allows_initial_spike() {
3829        let app = limited_router_with_burst(1, Some(3));
3830        for i in 0..3 {
3831            let resp = app.clone().oneshot(limited_req("10.4.4.4")).await.unwrap();
3832            assert_eq!(resp.status(), StatusCode::OK, "burst request {i}");
3833        }
3834        let resp = app.clone().oneshot(limited_req("10.4.4.4")).await.unwrap();
3835        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
3836    }
3837
3838    #[tokio::test]
3839    async fn extra_route_limiter_deny_sets_retry_after() {
3840        let app = limited_router(1);
3841        let ok = app.clone().oneshot(limited_req("10.5.5.5")).await.unwrap();
3842        assert_eq!(ok.status(), StatusCode::OK);
3843        let denied = app.clone().oneshot(limited_req("10.5.5.5")).await.unwrap();
3844        assert_eq!(denied.status(), StatusCode::TOO_MANY_REQUESTS);
3845        let retry_after = denied
3846            .headers()
3847            .get(header::RETRY_AFTER)
3848            .expect("Retry-After present")
3849            .to_str()
3850            .unwrap()
3851            .parse::<u64>()
3852            .unwrap();
3853        assert!(retry_after >= 1, "delta-seconds must be >= 1");
3854    }
3855
3856    #[test]
3857    fn validate_rejects_zero_burst_knobs() {
3858        let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
3859            .with_tool_rate_limit(10)
3860            .with_tool_rate_limit_burst(0)
3861            .validate()
3862            .expect_err("zero tool burst");
3863        assert!(err.to_string().contains("tool_rate_limit_burst"));
3864
3865        let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
3866            .with_extra_route_rate_limit(10)
3867            .with_extra_route_rate_limit_burst(0)
3868            .validate()
3869            .expect_err("zero extra route burst");
3870        assert!(err.to_string().contains("extra_route_rate_limit_burst"));
3871    }
3872
3873    #[test]
3874    fn validate_rejects_orphan_burst_knobs() {
3875        let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
3876            .with_tool_rate_limit_burst(5)
3877            .validate()
3878            .expect_err("orphan tool burst");
3879        assert!(err.to_string().contains("requires tool_rate_limit"));
3880
3881        let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
3882            .with_extra_route_rate_limit_burst(5)
3883            .validate()
3884            .expect_err("orphan extra route burst");
3885        assert!(err.to_string().contains("requires extra_route_rate_limit"));
3886    }
3887
3888    #[test]
3889    fn validate_rejects_zero_auth_bursts() {
3890        let auth = AuthConfig::with_keys(vec![])
3891            .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_burst(0));
3892        let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
3893            .with_auth(auth)
3894            .validate()
3895            .expect_err("zero auth burst");
3896        assert!(err.to_string().contains("rate_limit.burst"));
3897
3898        let auth = AuthConfig::with_keys(vec![])
3899            .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_pre_auth_burst(0));
3900        let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
3901            .with_auth(auth)
3902            .validate()
3903            .expect_err("zero pre-auth burst");
3904        assert!(err.to_string().contains("pre_auth_burst"));
3905    }
3906
3907    /// `pre_auth_burst` without `pre_auth_max_per_minute` is LEGAL: the
3908    /// pre-auth base rate always resolves (max_attempts_per_minute x 10).
3909    #[test]
3910    fn validate_accepts_pre_auth_burst_without_explicit_pre_auth_rate() {
3911        let auth = AuthConfig::with_keys(vec![])
3912            .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_pre_auth_burst(50));
3913        let cfg = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0").with_auth(auth);
3914        assert!(cfg.validate().is_ok(), "pre_auth_burst has no orphan rule");
3915    }
3916
3917    // -- trusted-forwarder mode (ClientIp / ForwardedHeaderMode) --
3918
3919    fn forward_resolver(trusted: &[&str], mode: ForwardedHeaderMode) -> Arc<ForwardResolver> {
3920        Arc::new(ForwardResolver {
3921            trusted: trusted.iter().map(|s| s.parse().unwrap()).collect(),
3922            mode,
3923        })
3924    }
3925
3926    /// Probe router reporting `"<PeerAddr ip>|<ClientIp>"`.
3927    fn forwarded_probe_router(resolver: Option<Arc<ForwardResolver>>) -> axum::Router {
3928        async fn probe(req: Request<Body>) -> String {
3929            let pa = req
3930                .extensions()
3931                .get::<PeerAddr>()
3932                .map(|p| p.addr.ip().to_string())
3933                .unwrap_or_default();
3934            let ci = req
3935                .extensions()
3936                .get::<ClientIp>()
3937                .map(|c| c.ip.to_string())
3938                .unwrap_or_default();
3939            format!("{pa}|{ci}")
3940        }
3941        axum::Router::new()
3942            .route("/probe", axum::routing::get(probe))
3943            .layer(axum::middleware::from_fn(move |req, next| {
3944                let r = resolver.clone();
3945                normalize_peer_addr_middleware(r, req, next)
3946            }))
3947    }
3948
3949    fn probe_req(peer: &str, header: Option<(&str, &str)>) -> Request<Body> {
3950        let addr: SocketAddr = peer.parse().unwrap();
3951        let mut builder = Request::builder()
3952            .uri("/probe")
3953            .extension(ConnectInfo(addr));
3954        if let Some((name, value)) = header {
3955            builder = builder.header(name, value);
3956        }
3957        builder.body(Body::empty()).unwrap()
3958    }
3959
3960    #[tokio::test]
3961    async fn client_ip_equals_direct_without_resolver() {
3962        let app = forwarded_probe_router(None);
3963        let resp = app
3964            .oneshot(probe_req(
3965                "10.1.2.3:4444",
3966                Some(("x-forwarded-for", "203.0.113.7")),
3967            ))
3968            .await
3969            .unwrap();
3970        assert_eq!(
3971            body_string(resp).await,
3972            "10.1.2.3|10.1.2.3",
3973            "feature off: header ignored, ClientIp == direct"
3974        );
3975    }
3976
3977    #[tokio::test]
3978    async fn client_ip_resolved_for_trusted_peer() {
3979        let app = forwarded_probe_router(Some(forward_resolver(
3980            &["10.0.0.0/8"],
3981            ForwardedHeaderMode::XForwardedFor,
3982        )));
3983        let resp = app
3984            .oneshot(probe_req(
3985                "10.0.0.1:9999",
3986                Some(("x-forwarded-for", "203.0.113.7")),
3987            ))
3988            .await
3989            .unwrap();
3990        assert_eq!(
3991            body_string(resp).await,
3992            "10.0.0.1|203.0.113.7",
3993            "PeerAddr stays direct while ClientIp resolves"
3994        );
3995    }
3996
3997    #[tokio::test]
3998    async fn client_ip_falls_back_to_direct_on_malformed_header() {
3999        let app = forwarded_probe_router(Some(forward_resolver(
4000            &["10.0.0.0/8"],
4001            ForwardedHeaderMode::XForwardedFor,
4002        )));
4003        let resp = app
4004            .oneshot(probe_req(
4005                "10.0.0.1:9999",
4006                Some(("x-forwarded-for", "not-an-ip")),
4007            ))
4008            .await
4009            .unwrap();
4010        assert_eq!(
4011            body_string(resp).await,
4012            "10.0.0.1|10.0.0.1",
4013            "malformed chain falls back to the direct peer"
4014        );
4015    }
4016
4017    #[test]
4018    fn forwarded_header_mode_deserializes_kebab_case() {
4019        #[derive(serde::Deserialize)]
4020        struct Wrapper {
4021            mode: ForwardedHeaderMode,
4022        }
4023        let w: Wrapper = toml::from_str(r#"mode = "x-forwarded-for""#).unwrap();
4024        assert_eq!(w.mode, ForwardedHeaderMode::XForwardedFor);
4025        let w: Wrapper = toml::from_str(r#"mode = "forwarded""#).unwrap();
4026        assert_eq!(w.mode, ForwardedHeaderMode::Forwarded);
4027        assert!(
4028            toml::from_str::<Wrapper>(r#"mode = "XForwardedFor""#).is_err(),
4029            "PascalCase wire value must be rejected"
4030        );
4031    }
4032
4033    #[test]
4034    fn validate_rejects_bad_trusted_proxy_entry() {
4035        let cfg = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4036            .with_trusted_proxies(["not-a-cidr"]);
4037        let err = cfg.validate().expect_err("bad CIDR");
4038        assert!(err.to_string().contains("trusted_proxies"));
4039    }
4040
4041    #[test]
4042    fn validate_accepts_cidr_and_bare_ip_proxy_entries() {
4043        let cfg = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0").with_trusted_proxies([
4044            "10.0.0.0/8",
4045            "192.0.2.1",
4046            "2001:db8::1",
4047        ]);
4048        assert!(cfg.validate().is_ok(), "CIDRs and bare IPs are accepted");
4049    }
4050
4051    #[test]
4052    fn validate_rejects_forwarded_header_without_proxies() {
4053        let cfg = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4054            .with_forwarded_header(ForwardedHeaderMode::Forwarded);
4055        let err = cfg.validate().expect_err("mode without proxies");
4056        assert!(err.to_string().contains("requires trusted_proxies"));
4057    }
4058
4059    // -- origin_check_middleware --
4060
4061    /// Build a test router with origin check middleware and a simple handler.
4062    fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
4063        let allowed: Arc<[String]> = Arc::from(origins);
4064        axum::Router::new()
4065            .route("/test", axum::routing::get(|| async { "ok" }))
4066            .layer(axum::middleware::from_fn(move |req, next| {
4067                let a = Arc::clone(&allowed);
4068                origin_check_middleware(a, log_request_headers, req, next)
4069            }))
4070    }
4071
4072    #[tokio::test]
4073    async fn origin_allowed_passes() {
4074        let app = origin_router(vec!["http://localhost:3000".into()], false);
4075        let req = Request::builder()
4076            .uri("/test")
4077            .header(header::ORIGIN, "http://localhost:3000")
4078            .body(Body::empty())
4079            .unwrap();
4080        let resp = app.oneshot(req).await.unwrap();
4081        assert_eq!(resp.status(), StatusCode::OK);
4082    }
4083
4084    #[tokio::test]
4085    async fn origin_rejected_returns_403() {
4086        let app = origin_router(vec!["http://localhost:3000".into()], false);
4087        let req = Request::builder()
4088            .uri("/test")
4089            .header(header::ORIGIN, "http://evil.com")
4090            .body(Body::empty())
4091            .unwrap();
4092        let resp = app.oneshot(req).await.unwrap();
4093        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
4094    }
4095
4096    #[tokio::test]
4097    async fn no_origin_header_passes() {
4098        let app = origin_router(vec!["http://localhost:3000".into()], false);
4099        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4100        let resp = app.oneshot(req).await.unwrap();
4101        assert_eq!(resp.status(), StatusCode::OK);
4102    }
4103
4104    #[tokio::test]
4105    async fn empty_allowlist_rejects_any_origin() {
4106        let app = origin_router(vec![], false);
4107        let req = Request::builder()
4108            .uri("/test")
4109            .header(header::ORIGIN, "http://anything.com")
4110            .body(Body::empty())
4111            .unwrap();
4112        let resp = app.oneshot(req).await.unwrap();
4113        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
4114    }
4115
4116    #[tokio::test]
4117    async fn empty_allowlist_passes_without_origin() {
4118        let app = origin_router(vec![], false);
4119        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4120        let resp = app.oneshot(req).await.unwrap();
4121        assert_eq!(resp.status(), StatusCode::OK);
4122    }
4123
4124    #[test]
4125    fn format_request_headers_redacts_sensitive_values() {
4126        let mut headers = axum::http::HeaderMap::new();
4127        headers.insert("authorization", "Bearer secret-token".parse().unwrap());
4128        headers.insert("cookie", "sid=abc".parse().unwrap());
4129        headers.insert("x-request-id", "req-123".parse().unwrap());
4130
4131        let out = format_request_headers_for_log(&headers);
4132        assert!(out.contains("authorization: [REDACTED]"));
4133        assert!(out.contains("cookie: [REDACTED]"));
4134        assert!(out.contains("x-request-id: req-123"));
4135        assert!(!out.contains("secret-token"));
4136    }
4137
4138    // -- security_headers_middleware --
4139
4140    fn security_router(is_tls: bool) -> axum::Router {
4141        security_router_with(is_tls, SecurityHeadersConfig::default())
4142    }
4143
4144    fn security_router_with(is_tls: bool, cfg: SecurityHeadersConfig) -> axum::Router {
4145        let cfg = Arc::new(cfg);
4146        axum::Router::new()
4147            .route("/test", axum::routing::get(|| async { "ok" }))
4148            .layer(axum::middleware::from_fn(move |req, next| {
4149                let c = Arc::clone(&cfg);
4150                security_headers_middleware(is_tls, c, req, next)
4151            }))
4152    }
4153
4154    #[tokio::test]
4155    async fn security_headers_set_on_response() {
4156        let app = security_router(false);
4157        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4158        let resp = app.oneshot(req).await.unwrap();
4159        assert_eq!(resp.status(), StatusCode::OK);
4160
4161        let h = resp.headers();
4162        assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
4163        assert_eq!(h.get("x-frame-options").unwrap(), "deny");
4164        assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
4165        assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
4166        assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
4167        assert_eq!(
4168            h.get("cross-origin-resource-policy").unwrap(),
4169            "same-origin"
4170        );
4171        assert_eq!(
4172            h.get("cross-origin-embedder-policy").unwrap(),
4173            "require-corp"
4174        );
4175        assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
4176        assert!(
4177            h.get("permissions-policy")
4178                .unwrap()
4179                .to_str()
4180                .unwrap()
4181                .contains("camera=()"),
4182            "permissions-policy must restrict browser features"
4183        );
4184        assert_eq!(
4185            h.get("content-security-policy").unwrap(),
4186            "default-src 'none'; frame-ancestors 'none'"
4187        );
4188        assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
4189        // No HSTS when TLS is off.
4190        assert!(h.get("strict-transport-security").is_none());
4191    }
4192
4193    #[tokio::test]
4194    async fn hsts_set_when_tls_enabled() {
4195        let app = security_router(true);
4196        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4197        let resp = app.oneshot(req).await.unwrap();
4198
4199        let hsts = resp.headers().get("strict-transport-security").unwrap();
4200        assert!(
4201            hsts.to_str().unwrap().contains("max-age=63072000"),
4202            "HSTS must set 2-year max-age"
4203        );
4204    }
4205
4206    // -- SecurityHeadersConfig validation + override semantics --
4207
4208    /// Build a minimal config with a custom SecurityHeadersConfig and
4209    /// drive it through `check()`. Returns the result so individual
4210    /// tests can assert on success or specific error messages.
4211    fn check_with_security_headers(headers: SecurityHeadersConfig) -> Result<(), McpxError> {
4212        let cfg =
4213            McpServerConfig::new("127.0.0.1:8080", "test", "0.0.0").with_security_headers(headers);
4214        cfg.check()
4215    }
4216
4217    #[test]
4218    fn security_headers_config_default_validates() {
4219        check_with_security_headers(SecurityHeadersConfig::default())
4220            .expect("default SecurityHeadersConfig must validate");
4221    }
4222
4223    #[test]
4224    fn security_headers_config_validate_accepts_empty_string() {
4225        // All twelve fields explicitly set to "" -> omit-everything mode.
4226        let h = SecurityHeadersConfig {
4227            x_content_type_options: Some(String::new()),
4228            x_frame_options: Some(String::new()),
4229            cache_control: Some(String::new()),
4230            referrer_policy: Some(String::new()),
4231            cross_origin_opener_policy: Some(String::new()),
4232            cross_origin_resource_policy: Some(String::new()),
4233            cross_origin_embedder_policy: Some(String::new()),
4234            permissions_policy: Some(String::new()),
4235            x_permitted_cross_domain_policies: Some(String::new()),
4236            content_security_policy: Some(String::new()),
4237            x_dns_prefetch_control: Some(String::new()),
4238            strict_transport_security: Some(String::new()),
4239        };
4240        check_with_security_headers(h).expect("Some(\"\") on every field must validate (omit-all)");
4241    }
4242
4243    #[test]
4244    fn security_headers_config_validate_rejects_bad_value() {
4245        // 0x07 (BEL) is not a valid HTTP header value char.
4246        let h = SecurityHeadersConfig {
4247            referrer_policy: Some("\u{0007}".into()),
4248            ..SecurityHeadersConfig::default()
4249        };
4250        let err = check_with_security_headers(h)
4251            .expect_err("control char in referrer_policy must reject");
4252        let msg = err.to_string();
4253        assert!(
4254            msg.contains("referrer_policy"),
4255            "error must name the offending field, got: {msg}"
4256        );
4257    }
4258
4259    #[test]
4260    fn security_headers_config_validate_rejects_hsts_preload() {
4261        let h = SecurityHeadersConfig {
4262            strict_transport_security: Some("max-age=63072000; includeSubDomains; preload".into()),
4263            ..SecurityHeadersConfig::default()
4264        };
4265        let err = check_with_security_headers(h).expect_err("HSTS with preload must reject");
4266        let msg = err.to_string();
4267        assert!(
4268            msg.contains("strict_transport_security"),
4269            "error must name the field, got: {msg}"
4270        );
4271        assert!(
4272            msg.to_lowercase().contains("preload"),
4273            "error must mention `preload`, got: {msg}"
4274        );
4275    }
4276
4277    #[test]
4278    fn security_headers_config_validate_rejects_hsts_preload_uppercase() {
4279        // Case-insensitive match.
4280        let h = SecurityHeadersConfig {
4281            strict_transport_security: Some("max-age=600; PRELOAD".into()),
4282            ..SecurityHeadersConfig::default()
4283        };
4284        check_with_security_headers(h).expect_err("HSTS preload check must be case-insensitive");
4285    }
4286
4287    #[tokio::test]
4288    async fn security_headers_override_honored() {
4289        // Override X-Frame-Options to SAMEORIGIN.
4290        let h = SecurityHeadersConfig {
4291            x_frame_options: Some("SAMEORIGIN".into()),
4292            ..SecurityHeadersConfig::default()
4293        };
4294        let app = security_router_with(false, h);
4295        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4296        let resp = app.oneshot(req).await.unwrap();
4297        assert_eq!(resp.status(), StatusCode::OK);
4298
4299        let xfo = resp.headers().get("x-frame-options").unwrap();
4300        assert_eq!(xfo, "SAMEORIGIN");
4301    }
4302
4303    #[tokio::test]
4304    async fn security_headers_empty_string_omits() {
4305        // Empty string on referrer-policy -> header absent.
4306        let h = SecurityHeadersConfig {
4307            referrer_policy: Some(String::new()),
4308            ..SecurityHeadersConfig::default()
4309        };
4310        let app = security_router_with(false, h);
4311        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4312        let resp = app.oneshot(req).await.unwrap();
4313        assert_eq!(resp.status(), StatusCode::OK);
4314
4315        assert!(
4316            resp.headers().get("referrer-policy").is_none(),
4317            "Some(\"\") must omit the header"
4318        );
4319        // Other defaults should still be present.
4320        assert_eq!(
4321            resp.headers().get("x-content-type-options").unwrap(),
4322            "nosniff"
4323        );
4324    }
4325
4326    #[tokio::test]
4327    async fn security_headers_hsts_only_when_tls() {
4328        // HSTS override is irrelevant when TLS is off.
4329        let h = SecurityHeadersConfig {
4330            strict_transport_security: Some("max-age=600".into()),
4331            ..SecurityHeadersConfig::default()
4332        };
4333        let app = security_router_with(false, h);
4334        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4335        let resp = app.oneshot(req).await.unwrap();
4336        assert!(
4337            resp.headers().get("strict-transport-security").is_none(),
4338            "HSTS must remain absent on plaintext deployments even with override"
4339        );
4340    }
4341
4342    // -- oauth_token_cache_headers_middleware --
4343
4344    #[cfg(feature = "oauth")]
4345    #[tokio::test]
4346    async fn oauth_token_cache_headers_set_pragma_and_vary() {
4347        let app = axum::Router::new()
4348            .route("/token", axum::routing::post(|| async { "{}" }))
4349            .layer(axum::middleware::from_fn(
4350                oauth_token_cache_headers_middleware,
4351            ));
4352        let req = Request::builder()
4353            .method("POST")
4354            .uri("/token")
4355            .body(Body::from("{}"))
4356            .unwrap();
4357        let resp = app.oneshot(req).await.unwrap();
4358        assert_eq!(resp.status(), StatusCode::OK);
4359
4360        let h = resp.headers();
4361        assert_eq!(
4362            h.get("pragma").unwrap(),
4363            "no-cache",
4364            "RFC 6749 §5.1: token responses must set Pragma: no-cache"
4365        );
4366        let vary_values: Vec<String> = h
4367            .get_all("vary")
4368            .iter()
4369            .filter_map(|v| v.to_str().ok().map(str::to_owned))
4370            .collect();
4371        assert!(
4372            vary_values
4373                .iter()
4374                .any(|v| v.eq_ignore_ascii_case("Authorization")),
4375            "RFC 6750 §5.4: Vary must include Authorization, got {vary_values:?}"
4376        );
4377    }
4378
4379    #[cfg(feature = "oauth")]
4380    #[tokio::test]
4381    async fn oauth_token_cache_headers_preserve_existing_vary() {
4382        // Simulates a handler/layer that already set `Vary: Accept-Encoding`
4383        // (e.g. compression). Our middleware must APPEND, not REPLACE.
4384        let app = axum::Router::new()
4385            .route(
4386                "/token",
4387                axum::routing::post(|| async {
4388                    axum::response::Response::builder()
4389                        .header("vary", "Accept-Encoding")
4390                        .body(axum::body::Body::from("{}"))
4391                        .unwrap()
4392                }),
4393            )
4394            .layer(axum::middleware::from_fn(
4395                oauth_token_cache_headers_middleware,
4396            ));
4397        let req = Request::builder()
4398            .method("POST")
4399            .uri("/token")
4400            .body(Body::empty())
4401            .unwrap();
4402        let resp = app.oneshot(req).await.unwrap();
4403
4404        let vary: Vec<String> = resp
4405            .headers()
4406            .get_all("vary")
4407            .iter()
4408            .filter_map(|v| v.to_str().ok().map(str::to_owned))
4409            .collect();
4410        assert!(
4411            vary.iter().any(|v| v.contains("Accept-Encoding")),
4412            "must preserve pre-existing Vary value, got {vary:?}"
4413        );
4414        assert!(
4415            vary.iter().any(|v| v.contains("Authorization")),
4416            "must append Authorization to Vary, got {vary:?}"
4417        );
4418    }
4419
4420    // -- version endpoint --
4421
4422    #[test]
4423    fn version_payload_contains_expected_fields() {
4424        let v = version_payload("my-server", "1.2.3");
4425        assert_eq!(v["name"], "my-server");
4426        assert_eq!(v["version"], "1.2.3");
4427        assert!(v["build_git_sha"].is_string());
4428        assert!(v["build_timestamp"].is_string());
4429        assert!(v["rust_version"].is_string());
4430        assert!(v["mcpx_version"].is_string());
4431    }
4432
4433    // -- concurrency limit layer --
4434
4435    #[tokio::test]
4436    async fn concurrency_limit_layer_composes_and_serves() {
4437        // We only assert the layer stack compiles and a single request
4438        // below the cap still succeeds. True back-pressure behaviour
4439        // requires a live HTTP server and is covered by integration tests.
4440        let app = axum::Router::new()
4441            .route("/ok", axum::routing::get(|| async { "ok" }))
4442            .layer(
4443                tower::ServiceBuilder::new()
4444                    .layer(axum::error_handling::HandleErrorLayer::new(
4445                        |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
4446                    ))
4447                    .layer(tower::load_shed::LoadShedLayer::new())
4448                    .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
4449            );
4450        let resp = app
4451            .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
4452            .await
4453            .unwrap();
4454        assert_eq!(resp.status(), StatusCode::OK);
4455    }
4456
4457    // -- compression layer --
4458
4459    #[tokio::test]
4460    async fn compression_layer_gzip_encodes_response() {
4461        use tower_http::compression::Predicate as _;
4462
4463        let big_body = "a".repeat(4096);
4464        let app = axum::Router::new()
4465            .route(
4466                "/big",
4467                axum::routing::get(move || {
4468                    let body = big_body.clone();
4469                    async move { body }
4470                }),
4471            )
4472            .layer(
4473                tower_http::compression::CompressionLayer::new()
4474                    .gzip(true)
4475                    .br(true)
4476                    .compress_when(
4477                        tower_http::compression::DefaultPredicate::new()
4478                            .and(tower_http::compression::predicate::SizeAbove::new(1024)),
4479                    ),
4480            );
4481
4482        let req = Request::builder()
4483            .uri("/big")
4484            .header(header::ACCEPT_ENCODING, "gzip")
4485            .body(Body::empty())
4486            .unwrap();
4487        let resp = app.oneshot(req).await.unwrap();
4488        assert_eq!(resp.status(), StatusCode::OK);
4489        assert_eq!(
4490            resp.headers().get(header::CONTENT_ENCODING).unwrap(),
4491            "gzip"
4492        );
4493    }
4494
4495    // -- TlsListener handshake timeout --
4496
4497    #[tokio::test]
4498    async fn tls_handshake_timeout_reaps_idle_connections() {
4499        use tokio::io::AsyncReadExt as _;
4500
4501        let _ = rustls::crypto::ring::default_provider().install_default();
4502
4503        // Self-signed cert material on disk (TlsListener::new takes paths).
4504        let key = rcgen::KeyPair::generate().expect("generate key");
4505        let cert = rcgen::CertificateParams::new(vec!["localhost".to_owned()])
4506            .expect("cert params")
4507            .self_signed(&key)
4508            .expect("self-signed cert");
4509        let dir = std::env::temp_dir().join(format!(
4510            "rmcp-server-kit-hs-timeout-{}",
4511            std::time::SystemTime::now()
4512                .duration_since(std::time::UNIX_EPOCH)
4513                .expect("clock after epoch")
4514                .as_nanos()
4515        ));
4516        tokio::fs::create_dir_all(&dir).await.expect("temp dir");
4517        let cert_path = dir.join("server.crt");
4518        let key_path = dir.join("server.key");
4519        tokio::fs::write(&cert_path, cert.pem())
4520            .await
4521            .expect("write cert");
4522        tokio::fs::write(&key_path, key.serialize_pem())
4523            .await
4524            .expect("write key");
4525
4526        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
4527        let tls = TlsListener::new(
4528            listener,
4529            &cert_path,
4530            &key_path,
4531            None,
4532            None,
4533            Duration::from_millis(200),
4534            8, // custom concurrency cap: proves the plumbing end-to-end
4535        )
4536        .expect("tls listener");
4537        let addr = axum::serve::Listener::local_addr(&tls).expect("local addr");
4538
4539        // Connect and send NOTHING: the handshake worker must time out
4540        // after 200ms and drop the stream, which the client observes as
4541        // EOF or a reset well within the 2s deadline.
4542        let mut idle = tokio::net::TcpStream::connect(addr).await.expect("connect");
4543        let mut buf = [0_u8; 16];
4544        let read = tokio::time::timeout(Duration::from_secs(2), idle.read(&mut buf))
4545            .await
4546            .expect("server must reap the idle handshake within its timeout");
4547        match read {
4548            Ok(0) | Err(_) => {} // EOF or reset: connection was dropped.
4549            Ok(n) => panic!("unexpected {n} bytes from server during reaped handshake"),
4550        }
4551
4552        drop(tls);
4553    }
4554}