Skip to main content

rmcp_server_kit/
transport.rs

1use std::{
2    future::Future,
3    net::{IpAddr, SocketAddr},
4    path::{Path, PathBuf},
5    pin::Pin,
6    sync::Arc,
7    time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{
12    body::Body,
13    extract::{ConnectInfo, Request},
14    middleware::Next,
15    response::IntoResponse,
16};
17use rmcp::{
18    ServerHandler,
19    transport::streamable_http_server::{
20        StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
21    },
22};
23use rustls::RootCertStore;
24use tokio::{
25    net::TcpListener,
26    sync::{Semaphore, mpsc},
27};
28use tokio_util::sync::CancellationToken;
29
30use crate::{
31    auth::{
32        AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
33        build_rate_limiter, extract_mtls_identity,
34    },
35    bounded_limiter::BoundedKeyedLimiter,
36    error::McpxError,
37    mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
38    rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
39};
40
41/// Map an internal `anyhow::Error` chain into a public [`McpxError::Startup`]
42/// at the public API boundary, flattening the chain via the alternate
43/// formatter so callers see the full causal path.
44#[allow(
45    clippy::needless_pass_by_value,
46    reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
47)]
48fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
49    McpxError::Startup(format!("{e:#}"))
50}
51
52/// Map a `std::io::Error` produced during server startup into a public
53/// [`McpxError::Startup`]. We deliberately do not use the [`McpxError::Io`]
54/// `From` impl here because startup-phase IO errors (bind, listener) are
55/// semantically distinct from request-time IO errors and should surface
56/// the originating operation in the message.
57#[allow(
58    clippy::needless_pass_by_value,
59    reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
60)]
61fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
62    McpxError::Startup(format!("{op}: {e}"))
63}
64
65/// Async readiness check callback for the `/readyz` endpoint.
66///
67/// Returns a JSON object with at least a `"ready"` boolean.
68/// When `ready` is false, the endpoint returns HTTP 503.
69pub type ReadinessCheck =
70    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
71
/// Direct socket peer address of the current HTTP/TLS connection.
///
/// Inserted as a request extension into every request served by [`serve`] —
/// on both the plain and the TLS listener — and extractable in any axum
/// handler, including routes mounted via
/// [`McpServerConfig::with_extra_router`] (which bypass auth/RBAC and
/// therefore often need the peer address for their own protection, e.g.
/// per-IP rate limiting).
///
/// The same address is also mirrored into
/// [`axum::extract::ConnectInfo<SocketAddr>`] on the TLS listener, so
/// third-party middleware that expects the stock axum extension (e.g.
/// per-IP rate-limit key extractors) works unmodified under TLS.
///
/// # Semantics
///
/// - **Direct peer only.** This is the socket's remote address. Behind an
///   L4/L7 proxy or load balancer it is the proxy's address. `PeerAddr`
///   itself never reflects `X-Forwarded-For` / `Forwarded` headers, even
///   when trusted-forwarder mode
///   ([`McpServerConfig::with_trusted_proxies`]) is active. Use
///   [`ClientIp`] when you need the resolved client address instead.
/// - **Available on HTTP and TLS** transports alike ([`serve`]).
/// - **Absent under [`serve_stdio`]** — a stdio session has no network
///   peer (stdio bypasses the HTTP stack entirely).
/// - The separate Prometheus metrics listener (feature `metrics`) is a
///   different router and does not carry this extension.
///
/// # Privacy
///
/// `PeerAddr` exposes raw peer network metadata. The framework deliberately
/// never logs it on its own; whether to log or persist peer addresses is
/// application policy.
///
/// # Example
///
/// ```no_run
/// use axum::{Router, routing::get};
/// use rmcp_server_kit::transport::{McpServerConfig, PeerAddr};
///
/// async fn whoami(peer: PeerAddr) -> String {
///     peer.addr.ip().to_string()
/// }
///
/// let _config = McpServerConfig::new("127.0.0.1:8443", "my-server", "1.0.0")
///     .with_extra_router(Router::new().route("/whoami", get(whoami)));
/// ```
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct PeerAddr {
    /// Direct socket peer of this connection (never header-derived).
    pub addr: SocketAddr,
}
122
123impl PeerAddr {
124    /// Construct a new [`PeerAddr`]. Framework-internal: downstream code
125    /// receives `PeerAddr` via request extensions and never constructs it.
126    #[must_use]
127    pub(crate) const fn new(addr: SocketAddr) -> Self {
128        Self { addr }
129    }
130}
131
132/// Extract [`PeerAddr`] from request extensions.
133///
134/// # Rejection
135///
136/// Responds `500 Internal Server Error` when the extension is missing.
137/// A missing `PeerAddr` means the handler is not running under [`serve`]
138/// (e.g. the router was mounted on a hand-rolled listener) — a wiring
139/// bug, not a client error.
140impl<S: Send + Sync> axum::extract::FromRequestParts<S> for PeerAddr {
141    type Rejection = (axum::http::StatusCode, &'static str);
142
143    async fn from_request_parts(
144        parts: &mut axum::http::request::Parts,
145        _state: &S,
146    ) -> Result<Self, Self::Rejection> {
147        parts.extensions.get::<Self>().copied().ok_or((
148            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
149            "peer address unavailable: not running under rmcp-server-kit serve()",
150        ))
151    }
152}
153
/// Client IP address resolved for the current request.
///
/// [`serve`] attaches this as a request extension to every request it
/// handles, immediately after [`PeerAddr`]. It equals the direct peer's IP,
/// except when **trusted-forwarder mode**
/// ([`McpServerConfig::with_trusted_proxies`]) is on and the request came
/// through a trusted proxy with a verifiable forwarding chain. In that case
/// it is the rightmost-untrusted address taken from `X-Forwarded-For` (or
/// from the RFC 7239 `Forwarded` header, as selected by
/// [`McpServerConfig::with_forwarded_header`]).
///
/// Every built-in per-IP rate limiter uses this value as its key.
/// [`PeerAddr`] keeps meaning the direct socket peer. To detect whether
/// resolution applied, compare `ClientIp.ip` with `PeerAddr.addr.ip()`.
///
/// # Security
///
/// Header-based resolution is only attempted when the **direct peer** lies
/// inside the operator's trusted-proxy CIDRs. Any ambiguous chain falls
/// back to the direct peer, never to a header value. This covers malformed
/// or obfuscated entries, chains made up entirely of trusted hops, and
/// header bombs. The framework never logs this value outside rate-limit
/// deny paths.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct ClientIp {
    /// Resolved client IP: the direct peer, unless trusted-forwarder
    /// resolution replaced it.
    pub ip: IpAddr,
}
182
183impl ClientIp {
184    /// Construct a new [`ClientIp`]. Framework-internal: downstream code
185    /// receives `ClientIp` via request extensions and never constructs it.
186    #[must_use]
187    pub(crate) const fn new(ip: IpAddr) -> Self {
188        Self { ip }
189    }
190}
191
192/// Which forwarding header trusted-forwarder mode reads.
193///
194/// TOML wire values are kebab-case: `"x-forwarded-for"` (default when
195/// unset) and `"forwarded"`.
196#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Deserialize)]
197#[serde(rename_all = "kebab-case")]
198#[non_exhaustive]
199pub enum ForwardedHeaderMode {
200    /// De-facto standard `X-Forwarded-For` list (nginx, HAProxy, CDNs).
201    XForwardedFor,
202    /// RFC 7239 `Forwarded` header (`for=` parameters).
203    Forwarded,
204}
205
206/// Pre-parsed trusted-forwarder configuration captured by the
207/// peer-normalization middleware.
208struct ForwardResolver {
209    trusted: Vec<ipnet::IpNet>,
210    mode: ForwardedHeaderMode,
211}
212
/// Overrides for the OWASP security headers that the global response
/// middleware adds to responses.
///
/// Every field is tri-state:
///
/// | Value         | Effect                                                   |
/// |---------------|----------------------------------------------------------|
/// | `None`        | Keep the built-in default (current behaviour).           |
/// | `Some("")`    | **Suppress** the header entirely.                        |
/// | `Some(value)` | Send `header: value`. Checked at config-load time.       |
///
/// [`McpServerConfig::validate`] runs every non-empty value through
/// [`axum::http::HeaderValue::from_str`], so malformed values are rejected
/// before the server starts accepting traffic.
///
/// `Strict-Transport-Security` carries one extra restriction: any value
/// containing the substring `preload` (case-insensitive) is refused.
/// Opting into the HSTS preload list is reserved for a future, explicit
/// builder method and cannot be done through this override.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct SecurityHeadersConfig {
    /// Replacement for `X-Content-Type-Options` (built-in: `nosniff`).
    pub x_content_type_options: Option<String>,
    /// Replacement for `X-Frame-Options` (built-in: `deny`).
    pub x_frame_options: Option<String>,
    /// Replacement for `Cache-Control` (built-in: `no-store, max-age=0`).
    pub cache_control: Option<String>,
    /// Replacement for `Referrer-Policy` (built-in: `no-referrer`).
    pub referrer_policy: Option<String>,
    /// Replacement for `Cross-Origin-Opener-Policy` (built-in: `same-origin`).
    pub cross_origin_opener_policy: Option<String>,
    /// Replacement for `Cross-Origin-Resource-Policy` (built-in: `same-origin`).
    pub cross_origin_resource_policy: Option<String>,
    /// Replacement for `Cross-Origin-Embedder-Policy` (built-in: `require-corp`).
    pub cross_origin_embedder_policy: Option<String>,
    /// Replacement for `Permissions-Policy` (built-in:
    /// `accelerometer=(), camera=(), geolocation=(), microphone=()`).
    pub permissions_policy: Option<String>,
    /// Replacement for `X-Permitted-Cross-Domain-Policies` (built-in: `none`).
    pub x_permitted_cross_domain_policies: Option<String>,
    /// Replacement for `Content-Security-Policy` (built-in:
    /// `default-src 'none'; frame-ancestors 'none'`).
    pub content_security_policy: Option<String>,
    /// Replacement for `X-DNS-Prefetch-Control` (built-in: `off`).
    pub x_dns_prefetch_control: Option<String>,
    /// Replacement for `Strict-Transport-Security` (built-in, TLS only:
    /// `max-age=63072000; includeSubDomains`). The header is only sent
    /// when TLS is active, so this override has no effect on plaintext
    /// deployments. Values containing `preload` (any case) are rejected by
    /// the validator.
    pub strict_transport_security: Option<String>,
}
266
267/// Configuration for the MCP server.
268#[allow(
269    missing_debug_implementations,
270    reason = "contains callback/trait objects that don't impl Debug"
271)]
272#[allow(
273    clippy::struct_excessive_bools,
274    reason = "server configuration naturally has many boolean feature flags"
275)]
276#[non_exhaustive]
277pub struct McpServerConfig {
278    /// Socket address the MCP HTTP server binds to.
279    #[deprecated(
280        since = "0.13.0",
281        note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in a future major release"
282    )]
283    pub bind_addr: String,
284    /// Server name advertised via MCP `initialize`.
285    #[deprecated(
286        since = "0.13.0",
287        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
288    )]
289    pub name: String,
290    /// Server version advertised via MCP `initialize`.
291    #[deprecated(
292        since = "0.13.0",
293        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
294    )]
295    pub version: String,
296    /// Path to the TLS certificate (PEM). Required for TLS/mTLS.
297    #[deprecated(
298        since = "0.13.0",
299        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
300    )]
301    pub tls_cert_path: Option<PathBuf>,
302    /// Path to the TLS private key (PEM). Required for TLS/mTLS.
303    #[deprecated(
304        since = "0.13.0",
305        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
306    )]
307    pub tls_key_path: Option<PathBuf>,
308    /// Optional authentication config. When `Some` and `enabled`, auth
309    /// is enforced on `/mcp`. `/healthz` is always open.
310    #[deprecated(
311        since = "0.13.0",
312        note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in a future major release"
313    )]
314    pub auth: Option<AuthConfig>,
315    /// Optional RBAC policy. When present and enabled, tool calls are
316    /// checked against the policy after authentication.
317    #[deprecated(
318        since = "0.13.0",
319        note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in a future major release"
320    )]
321    pub rbac: Option<Arc<RbacPolicy>>,
322    /// Allowed Origin values for DNS rebinding protection (MCP spec MUST).
323    /// When empty and `public_url` is set, the origin is auto-derived from
324    /// the public URL. When both are empty, only requests with no Origin
325    /// header are accepted.
326    /// Example entries: `"http://localhost:3000"`, `"https://myapp.example.com"`.
327    #[deprecated(
328        since = "0.13.0",
329        note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in a future major release"
330    )]
331    pub allowed_origins: Vec<String>,
332    /// Maximum tool invocations per source IP per minute.
333    /// When set, enforced on every `tools/call` request.
334    #[deprecated(
335        since = "0.13.0",
336        note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in a future major release"
337    )]
338    pub tool_rate_limit: Option<u32>,
339    /// Burst capacity for the tool rate limiter: maximum `tools/call`
340    /// requests admitted back-to-back before the sustained
341    /// [`tool_rate_limit`](Self::tool_rate_limit) rate applies. `None`
342    /// (default) keeps governor's default of burst = rate. Requires
343    /// `tool_rate_limit` to be set; must be greater than zero.
344    #[deprecated(
345        since = "1.12.0",
346        note = "use McpServerConfig::with_tool_rate_limit_burst(); direct field access will become pub(crate) in a future major release"
347    )]
348    pub tool_rate_limit_burst: Option<u32>,
349    /// Maximum requests per source IP per minute for routes merged via
350    /// [`with_extra_router`](Self::with_extra_router). Opt-in: `None`
351    /// (the default) installs no limiter. Startup-only (not
352    /// hot-reloadable via [`ReloadHandle`]).
353    ///
354    /// Keyed by the **direct socket peer** ([`PeerAddr`] semantics — no
355    /// `X-Forwarded-For` interpretation): behind a reverse proxy all
356    /// clients share the proxy's bucket, and IPv6 single-host address
357    /// rotation can evade per-IP keying. Treat this as an abuse speed
358    /// bump for unauthenticated application endpoints, not tenant
359    /// isolation. On limit: HTTP 429 with a plain-text body, matching
360    /// the tool/auth limiters.
361    #[deprecated(
362        since = "1.11.0",
363        note = "use McpServerConfig::with_extra_route_rate_limit(); direct field access will become pub(crate) in a future major release"
364    )]
365    pub extra_route_rate_limit: Option<u32>,
366    /// Burst capacity for the extra-route limiter: maximum requests
367    /// admitted back-to-back before the sustained
368    /// [`extra_route_rate_limit`](Self::extra_route_rate_limit) rate
369    /// applies. `None` (default) keeps governor's default of
370    /// burst = rate. Requires `extra_route_rate_limit` to be set; must
371    /// be greater than zero.
372    #[deprecated(
373        since = "1.12.0",
374        note = "use McpServerConfig::with_extra_route_rate_limit_burst(); direct field access will become pub(crate) in a future major release"
375    )]
376    pub extra_route_rate_limit_burst: Option<u32>,
377    /// Exact-match request paths exempt from the extra-route rate
378    /// limiter (e.g. `/.well-known/oauth-authorization-server`, which
379    /// MCP clients fetch on every connect — behind a shared egress the
380    /// limiter would otherwise 429 discovery). Matching is a **raw
381    /// exact string comparison** against `req.uri().path()`: no globs,
382    /// no prefixes, no normalization — trailing slashes,
383    /// percent-encoding, and dot-segments must match byte-for-byte.
384    /// Fail-closed: any path not listed stays rate-limited (a mismatch
385    /// can only mean "still limited", never "accidentally exempt").
386    /// Requires [`extra_route_rate_limit`](Self::extra_route_rate_limit);
387    /// each entry must be non-empty and start with `/` (validated).
388    /// Startup-only.
389    #[deprecated(
390        since = "1.14.0",
391        note = "use McpServerConfig::with_extra_route_rate_limit_exempt_paths(); direct field access will become pub(crate) in a future major release"
392    )]
393    pub extra_route_rate_limit_exempt_paths: Vec<String>,
394    /// Trusted reverse-proxy networks (CIDRs or bare IPs) for
395    /// **trusted-forwarder mode**. Empty (default) = mode off: every
396    /// limiter keys by the direct socket peer. Nonempty = requests whose
397    /// direct peer is inside one of these networks have their client IP
398    /// resolved from the forwarding header (rightmost-untrusted walk);
399    /// see [`ClientIp`]. Only enable when **all** ingress paths traverse
400    /// the listed proxies. Startup-only.
401    #[deprecated(
402        since = "1.13.0",
403        note = "use McpServerConfig::with_trusted_proxies(); direct field access will become pub(crate) in a future major release"
404    )]
405    pub trusted_proxies: Vec<String>,
406    /// Which forwarding header trusted-forwarder mode reads. `None`
407    /// (default) = `X-Forwarded-For`. Setting this requires
408    /// [`trusted_proxies`](Self::trusted_proxies) to be nonempty
409    /// (validated). Startup-only.
410    #[deprecated(
411        since = "1.13.0",
412        note = "use McpServerConfig::with_forwarded_header(); direct field access will become pub(crate) in a future major release"
413    )]
414    pub forwarded_header: Option<ForwardedHeaderMode>,
415    /// Optional readiness probe for `/readyz`.
416    /// When `None`, `/readyz` mirrors `/healthz` (always OK).
417    #[deprecated(
418        since = "0.13.0",
419        note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in a future major release"
420    )]
421    pub readiness_check: Option<ReadinessCheck>,
422    /// Maximum request body size in bytes. Default: 1 MiB.
423    /// Protects against oversized payloads causing OOM.
424    #[deprecated(
425        since = "0.13.0",
426        note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in a future major release"
427    )]
428    pub max_request_body: usize,
429    /// Request processing timeout. Default: 120s.
430    /// Requests exceeding this duration receive 408 Request Timeout.
431    #[deprecated(
432        since = "0.13.0",
433        note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in a future major release"
434    )]
435    pub request_timeout: Duration,
436    /// Graceful shutdown timeout. Default: 30s.
437    /// After the shutdown signal, in-flight requests have this long to finish.
438    #[deprecated(
439        since = "0.13.0",
440        note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in a future major release"
441    )]
442    pub shutdown_timeout: Duration,
443    /// Idle timeout for MCP sessions. Sessions with no activity for this
444    /// duration are closed automatically. Default: 20 minutes.
445    #[deprecated(
446        since = "0.13.0",
447        note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in a future major release"
448    )]
449    pub session_idle_timeout: Duration,
450    /// Interval for SSE keep-alive pings. Prevents proxies and load
451    /// balancers from killing idle connections. Default: 15 seconds.
452    #[deprecated(
453        since = "0.13.0",
454        note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in a future major release"
455    )]
456    pub sse_keep_alive: Duration,
457    /// Callback invoked once the server is built, delivering a
458    /// [`ReloadHandle`] for hot-reloading auth keys and RBAC policy
459    /// at runtime (e.g. on SIGHUP). Only useful when auth/RBAC is enabled.
460    #[deprecated(
461        since = "0.13.0",
462        note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in a future major release"
463    )]
464    pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
465    /// Additional application-specific routes merged into the top-level
466    /// router.  These routes **bypass** the MCP auth and RBAC middleware,
467    /// so the application is responsible for its own auth on them.
468    /// Handlers can extract [`PeerAddr`] (or
469    /// [`axum::extract::ConnectInfo<SocketAddr>`] for third-party
470    /// middleware compatibility) regardless of whether TLS is enabled.
471    #[deprecated(
472        since = "0.13.0",
473        note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in a future major release"
474    )]
475    pub extra_router: Option<axum::Router>,
476    /// Externally reachable base URL (e.g. `https://mcp.example.com`).
477    /// When set, OAuth metadata endpoints advertise this URL instead of
478    /// the listen address. Required when binding `0.0.0.0` behind a
479    /// reverse proxy or inside a container.
480    #[deprecated(
481        since = "0.13.0",
482        note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in a future major release"
483    )]
484    pub public_url: Option<String>,
485    /// Log inbound HTTP request headers at DEBUG level.
486    /// Sensitive values remain redacted.
487    #[deprecated(
488        since = "0.13.0",
489        note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in a future major release"
490    )]
491    pub log_request_headers: bool,
492    /// Enable gzip/br response compression on MCP responses.
493    /// Defaults to `false` to preserve existing behaviour.
494    #[deprecated(
495        since = "0.13.0",
496        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
497    )]
498    pub compression_enabled: bool,
499    /// Minimum response body size (in bytes) before compression kicks in.
500    /// Only used when `compression_enabled` is true. Default: 1024.
501    #[deprecated(
502        since = "0.13.0",
503        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
504    )]
505    pub compression_min_size: u16,
506    /// Global cap on in-flight HTTP requests across the whole server.
507    /// When `Some`, requests over the cap receive 503 Service Unavailable
508    /// via `tower::load_shed`. Default: `None` (unlimited).
509    #[deprecated(
510        since = "0.13.0",
511        note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in a future major release"
512    )]
513    pub max_concurrent_requests: Option<usize>,
514    /// Enable `/admin/*` diagnostic endpoints. Requires `auth` to be
515    /// configured and `enabled`. Default: `false`.
516    #[deprecated(
517        since = "0.13.0",
518        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
519    )]
520    pub admin_enabled: bool,
521    /// RBAC role required to access admin endpoints. Default: `"admin"`.
522    #[deprecated(
523        since = "0.13.0",
524        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
525    )]
526    pub admin_role: String,
527    /// Enable Prometheus metrics endpoint on a separate listener.
528    /// Requires the `metrics` crate feature.
529    #[cfg(feature = "metrics")]
530    #[deprecated(
531        since = "0.13.0",
532        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
533    )]
534    pub metrics_enabled: bool,
535    /// Bind address for the Prometheus metrics listener. Default: `127.0.0.1:9090`.
536    #[cfg(feature = "metrics")]
537    #[deprecated(
538        since = "0.13.0",
539        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
540    )]
541    pub metrics_bind: String,
542    /// Per-header overrides for the OWASP security headers emitted by
543    /// the global response middleware. See [`SecurityHeadersConfig`]
544    /// for the three-state semantic and validation rules.
545    #[deprecated(
546        since = "1.5.0",
547        note = "use McpServerConfig::with_security_headers(); direct field access will become pub(crate) in a future major release"
548    )]
549    pub security_headers: SecurityHeadersConfig,
550    /// Per-handshake deadline on the TLS accept path. Idle or slow-loris
551    /// connections are dropped once it elapses. Default: 10s.
552    ///
553    /// Startup-only: bound at listener construction, NOT hot-reloadable
554    /// via [`ReloadHandle`]. Ignored unless TLS is configured.
555    #[deprecated(
556        since = "1.9.0",
557        note = "use McpServerConfig::with_tls_handshake_timeout(); direct field access will become pub(crate) in a future major release"
558    )]
559    pub tls_handshake_timeout: Duration,
560    /// Cap on concurrently in-flight TLS handshakes. At saturation the
561    /// acceptor stops pulling new connections from the kernel backlog
562    /// (backpressure) instead of accepting and dropping. Default: 256.
563    ///
564    /// Startup-only: bound at listener construction, NOT hot-reloadable
565    /// via [`ReloadHandle`]. Ignored unless TLS is configured.
566    #[deprecated(
567        since = "1.9.0",
568        note = "use McpServerConfig::with_max_concurrent_tls_handshakes(); direct field access will become pub(crate) in a future major release"
569    )]
570    pub max_concurrent_tls_handshakes: usize,
571}
572
/// Proof-carrying wrapper: holding a `Validated<T>` means the wrapped value
/// has already passed its validation checks.
///
/// [`McpServerConfig::validate`] is the sole producer of
/// `Validated<McpServerConfig>`. Both [`serve`] and
/// [`serve_with_listener`] demand this wrapper, so the type system ensures
/// only validated configurations reach a running server. The wrapped field
/// is private, so downstream code cannot hand-build the wrapper to skip
/// validation.
///
/// Borrow the contents read-only with [`Validated::as_inner`]. To make
/// changes, take the value back out with [`Validated::into_inner`], edit it,
/// and validate it again.
///
/// # Example
///
/// ```no_run
/// use rmcp_server_kit::transport::{McpServerConfig, Validated, serve};
/// use rmcp::handler::server::ServerHandler;
/// use rmcp::model::{ServerCapabilities, ServerInfo};
///
/// #[derive(Clone)]
/// struct H;
/// impl ServerHandler for H {
///     fn get_info(&self) -> ServerInfo {
///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
///     }
/// }
///
/// # async fn example() -> rmcp_server_kit::Result<()> {
/// let config: Validated<McpServerConfig> =
///     McpServerConfig::new("127.0.0.1:8080", "my-server", "0.1.0").validate()?;
/// serve(config, || H).await
/// # }
/// ```
///
/// Leaving out `.validate()?` fails to compile:
///
/// ```compile_fail
/// use rmcp_server_kit::transport::{McpServerConfig, serve};
/// use rmcp::handler::server::ServerHandler;
/// use rmcp::model::{ServerCapabilities, ServerInfo};
///
/// #[derive(Clone)]
/// struct H;
/// impl ServerHandler for H {
///     fn get_info(&self) -> ServerInfo {
///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
///     }
/// }
///
/// # async fn example() -> rmcp_server_kit::Result<()> {
/// let config = McpServerConfig::new("127.0.0.1:8080", "my-server", "0.1.0");
/// // Missing `.validate()?` -> mismatched types: expected
/// // `Validated<McpServerConfig>`, found `McpServerConfig`.
/// serve(config, || H).await
/// # }
/// ```
#[allow(
    missing_debug_implementations,
    reason = "wraps T which may not implement Debug; manual impl below avoids leaking inner contents into logs"
)]
pub struct Validated<T>(T);
635
636impl<T> std::fmt::Debug for Validated<T> {
637    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
638        f.debug_struct("Validated").finish_non_exhaustive()
639    }
640}
641
642impl<T> Validated<T> {
643    /// Borrow the inner value.
644    #[must_use]
645    pub fn as_inner(&self) -> &T {
646        &self.0
647    }
648
649    /// Recover the raw value, discarding the validation proof.
650    ///
651    /// Re-validate before re-using the value with [`serve`] or
652    /// [`serve_with_listener`].
653    #[must_use]
654    pub fn into_inner(self) -> T {
655        self.0
656    }
657}
658
659#[allow(
660    deprecated,
661    reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
662)]
663impl McpServerConfig {
664    /// Create a new server configuration with the given bind address,
665    /// server name, and version. All other fields use safe defaults.
666    ///
667    /// Use the chainable `with_*` / `enable_*` builder methods to
668    /// customize. Call [`McpServerConfig::validate`] to obtain a
669    /// [`Validated<McpServerConfig>`] proof token, which is required by
670    /// [`serve`] and [`serve_with_listener`].
671    #[must_use]
672    pub fn new(
673        bind_addr: impl Into<String>,
674        name: impl Into<String>,
675        version: impl Into<String>,
676    ) -> Self {
677        Self {
678            bind_addr: bind_addr.into(),
679            name: name.into(),
680            version: version.into(),
681            tls_cert_path: None,
682            tls_key_path: None,
683            auth: None,
684            rbac: None,
685            allowed_origins: Vec::new(),
686            tool_rate_limit: None,
687            readiness_check: None,
688            max_request_body: 1024 * 1024,
689            request_timeout: Duration::from_mins(2),
690            shutdown_timeout: Duration::from_secs(30),
691            session_idle_timeout: Duration::from_mins(20),
692            sse_keep_alive: Duration::from_secs(15),
693            on_reload_ready: None,
694            extra_router: None,
695            public_url: None,
696            log_request_headers: false,
697            compression_enabled: false,
698            compression_min_size: 1024,
699            max_concurrent_requests: None,
700            admin_enabled: false,
701            admin_role: "admin".to_owned(),
702            #[cfg(feature = "metrics")]
703            metrics_enabled: false,
704            #[cfg(feature = "metrics")]
705            metrics_bind: "127.0.0.1:9090".into(),
706            security_headers: SecurityHeadersConfig::default(),
707            tls_handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
708            max_concurrent_tls_handshakes: DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES,
709            extra_route_rate_limit: None,
710            tool_rate_limit_burst: None,
711            extra_route_rate_limit_burst: None,
712            extra_route_rate_limit_exempt_paths: Vec::new(),
713            trusted_proxies: Vec::new(),
714            forwarded_header: None,
715        }
716    }
717
718    // ---------------------------------------------------------------
719    // Builder methods (fluent, consume + return self).
720    //
721    // Each method is `#[must_use]` because dropping the returned
722    // `McpServerConfig` discards the configuration change.
723    // ---------------------------------------------------------------
724
725    /// Attach an authentication configuration. Required for
726    /// [`enable_admin`](Self::enable_admin) and any non-public deployment.
727    #[must_use]
728    pub fn with_auth(mut self, auth: AuthConfig) -> Self {
729        self.auth = Some(auth);
730        self
731    }
732
733    /// Override one or more of the OWASP security headers emitted on
734    /// every response. See [`SecurityHeadersConfig`] for the three-state
735    /// semantic (`None` = default, `Some("")` = omit, `Some(v)` =
736    /// override). Values are validated by [`Self::validate`].
737    #[must_use]
738    pub fn with_security_headers(mut self, headers: SecurityHeadersConfig) -> Self {
739        self.security_headers = headers;
740        self
741    }
742
743    /// Override the bind address (e.g. `127.0.0.1:8080`). Useful when the
744    /// final port is only known after pre-binding an ephemeral listener
745    /// (tests, dynamic-port deployments).
746    #[must_use]
747    pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
748        self.bind_addr = addr.into();
749        self
750    }
751
752    /// Attach an RBAC policy. Tool calls are checked against the policy
753    /// after authentication.
754    #[must_use]
755    pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
756        self.rbac = Some(rbac);
757        self
758    }
759
760    /// Configure TLS by providing the certificate and private key paths
761    /// (PEM). Both must be readable at startup. Without this call, the
762    /// server runs plain HTTP.
763    #[must_use]
764    pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
765        self.tls_cert_path = Some(cert_path.into());
766        self.tls_key_path = Some(key_path.into());
767        self
768    }
769
770    /// Set the externally reachable base URL (e.g. `https://mcp.example.com`).
771    /// Required when binding `0.0.0.0` behind a reverse proxy or inside
772    /// a container so OAuth metadata and auto-derived origins resolve correctly.
773    #[must_use]
774    pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
775        self.public_url = Some(url.into());
776        self
777    }
778
779    /// Replace the allowed Origin allow-list (DNS-rebinding protection).
780    /// When empty and [`with_public_url`](Self::with_public_url) is set,
781    /// the origin is auto-derived.
782    #[must_use]
783    pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
784    where
785        I: IntoIterator<Item = S>,
786        S: Into<String>,
787    {
788        self.allowed_origins = origins.into_iter().map(Into::into).collect();
789        self
790    }
791
792    /// Merge an additional axum router at the top level. Routes added
793    /// here **bypass** rmcp-server-kit auth and RBAC; the application is responsible
794    /// for its own protection.
795    ///
796    /// To support that protection (e.g. per-IP rate limiting on
797    /// unauthenticated endpoints), every request served by [`serve`]
798    /// carries the client peer address regardless of whether TLS is
799    /// enabled: extract the framework-owned [`PeerAddr`] in your
800    /// handlers, or rely on [`axum::extract::ConnectInfo<SocketAddr>`]
801    /// for stock third-party middleware (e.g. per-IP rate-limit key
802    /// extractors). Neither extension exists under [`serve_stdio`],
803    /// which has no network peer.
804    #[must_use]
805    pub fn with_extra_router(mut self, router: axum::Router) -> Self {
806        self.extra_router = Some(router);
807        self
808    }
809
810    /// Install an async readiness probe for `/readyz`. Without this call,
811    /// `/readyz` mirrors `/healthz` (always 200 OK).
812    #[must_use]
813    pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
814        self.readiness_check = Some(check);
815        self
816    }
817
818    /// Override the maximum request body (bytes). Must be `> 0`.
819    /// Default: 1 MiB.
820    #[must_use]
821    pub fn with_max_request_body(mut self, bytes: usize) -> Self {
822        self.max_request_body = bytes;
823        self
824    }
825
826    /// Override the per-request processing timeout. Default: 2 minutes.
827    #[must_use]
828    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
829        self.request_timeout = timeout;
830        self
831    }
832
833    /// Override the graceful shutdown grace period. Default: 30 seconds.
834    #[must_use]
835    pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
836        self.shutdown_timeout = timeout;
837        self
838    }
839
840    /// Override the MCP session idle timeout. Default: 20 minutes.
841    #[must_use]
842    pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
843        self.session_idle_timeout = timeout;
844        self
845    }
846
847    /// Override the SSE keep-alive interval. Default: 15 seconds.
848    #[must_use]
849    pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
850        self.sse_keep_alive = interval;
851        self
852    }
853
854    /// Cap the global number of in-flight HTTP requests via
855    /// `tower::load_shed`. Excess requests receive 503 Service Unavailable.
856    /// Default: unlimited.
857    #[must_use]
858    pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
859        self.max_concurrent_requests = Some(limit);
860        self
861    }
862
863    /// Override the per-handshake deadline on the TLS accept path.
864    /// Idle or slow-loris connections are dropped once it elapses.
865    /// Default: 10s. Must be greater than zero.
866    ///
867    /// Startup-only: the value is bound at listener construction and is
868    /// NOT hot-reloadable via [`ReloadHandle`]. Has no effect unless TLS
869    /// is configured via [`Self::with_tls`].
870    #[must_use]
871    pub fn with_tls_handshake_timeout(mut self, timeout: Duration) -> Self {
872        self.tls_handshake_timeout = timeout;
873        self
874    }
875
876    /// Override the cap on concurrently in-flight TLS handshakes. At
877    /// saturation the acceptor stops pulling new connections from the
878    /// kernel backlog (backpressure) instead of accepting and dropping.
879    /// Default: 256. Must be greater than zero.
880    ///
881    /// Startup-only: the value is bound at listener construction and is
882    /// NOT hot-reloadable via [`ReloadHandle`]. Has no effect unless TLS
883    /// is configured via [`Self::with_tls`].
884    #[must_use]
885    pub fn with_max_concurrent_tls_handshakes(mut self, limit: usize) -> Self {
886        self.max_concurrent_tls_handshakes = limit;
887        self
888    }
889
890    /// Cap tool invocations per source IP per minute. Enforced on every
891    /// `tools/call` request.
892    #[must_use]
893    pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
894        self.tool_rate_limit = Some(per_minute);
895        self
896    }
897
898    /// Cap requests per source IP per minute on routes merged via
899    /// [`with_extra_router`](Self::with_extra_router) — the natural
900    /// protection for unauthenticated application endpoints (OAuth
901    /// callbacks, registration, …) that bypass auth/RBAC by design.
902    ///
903    /// Must be greater than zero (validated by
904    /// [`validate`](Self::validate)). Startup-only. See the
905    /// `extra_route_rate_limit` field docs for keying semantics and
906    /// caveats (direct peer only, IPv6 rotation, proxy collapse,
907    /// bounded-memory shared-fate under key spray).
908    #[must_use]
909    pub fn with_extra_route_rate_limit(mut self, per_minute: u32) -> Self {
910        self.extra_route_rate_limit = Some(per_minute);
911        self
912    }
913
914    /// Set the burst capacity for the tool rate limiter (bucket size;
915    /// the sustained rate stays [`with_tool_rate_limit`](Self::with_tool_rate_limit)).
916    /// Requires the tool rate limit to be set; must be greater than zero
917    /// (both validated by [`validate`](Self::validate)).
918    #[must_use]
919    pub fn with_tool_rate_limit_burst(mut self, burst: u32) -> Self {
920        self.tool_rate_limit_burst = Some(burst);
921        self
922    }
923
924    /// Set the burst capacity for the extra-route rate limiter (bucket
925    /// size; the sustained rate stays
926    /// [`with_extra_route_rate_limit`](Self::with_extra_route_rate_limit)).
927    /// Requires the extra-route rate limit to be set; must be greater
928    /// than zero (both validated by [`validate`](Self::validate)).
929    #[must_use]
930    pub fn with_extra_route_rate_limit_burst(mut self, burst: u32) -> Self {
931        self.extra_route_rate_limit_burst = Some(burst);
932        self
933    }
934
935    /// Exempt specific request paths from the extra-route rate limiter.
936    ///
937    /// Matching is a **raw exact string comparison** against
938    /// `req.uri().path()` — no globs, no prefixes, no normalization
939    /// (trailing slashes, percent-encoding, and dot-segments must match
940    /// byte-for-byte). The check is fail-closed: anything not listed
941    /// stays rate-limited, so a mismatch can only keep a request
942    /// limited, never accidentally exempt it. The exemption is checked
943    /// before key extraction, so exempt requests consume no limiter
944    /// budget and never appear in deny telemetry.
945    ///
946    /// Typical use: the RFC 8414 authorization-server metadata document
947    /// (`/.well-known/oauth-authorization-server`), fetched by MCP
948    /// clients on every connect.
949    ///
950    /// Requires the extra-route rate limit to be set
951    /// ([`with_extra_route_rate_limit`](Self::with_extra_route_rate_limit));
952    /// each entry must be non-empty and start with `/` (both validated
953    /// by [`validate`](Self::validate)). Startup-only.
954    #[must_use]
955    pub fn with_extra_route_rate_limit_exempt_paths<I, S>(mut self, paths: I) -> Self
956    where
957        I: IntoIterator<Item = S>,
958        S: Into<String>,
959    {
960        self.extra_route_rate_limit_exempt_paths = paths.into_iter().map(Into::into).collect();
961        self
962    }
963
964    /// Enable **trusted-forwarder mode**: requests whose direct peer is
965    /// inside one of these networks (CIDRs or bare IPs) have their
966    /// client IP resolved from the forwarding header via the
967    /// rightmost-untrusted walk; all per-IP rate limiters then key by
968    /// the resolved [`ClientIp`]. Headers from peers outside these
969    /// networks are ignored entirely.
970    ///
971    /// Only enable when **all** ingress paths traverse the listed
972    /// proxies — otherwise direct clients keep their own buckets and
973    /// proxied clients collapse into the proxy's. Entries are validated
974    /// at [`validate`](Self::validate) time. Startup-only.
975    #[must_use]
976    pub fn with_trusted_proxies<I, S>(mut self, proxies: I) -> Self
977    where
978        I: IntoIterator<Item = S>,
979        S: Into<String>,
980    {
981        self.trusted_proxies = proxies.into_iter().map(Into::into).collect();
982        self
983    }
984
985    /// Select which forwarding header trusted-forwarder mode reads
986    /// (default: `X-Forwarded-For`). Requires
987    /// [`with_trusted_proxies`](Self::with_trusted_proxies) to be set
988    /// (validated).
989    #[must_use]
990    pub fn with_forwarded_header(mut self, mode: ForwardedHeaderMode) -> Self {
991        self.forwarded_header = Some(mode);
992        self
993    }
994
995    /// Register a callback that receives the [`ReloadHandle`] after the
996    /// server is built. Use it to wire SIGHUP-style hot reloads of API
997    /// keys and RBAC policy.
998    #[must_use]
999    pub fn with_reload_callback<F>(mut self, callback: F) -> Self
1000    where
1001        F: FnOnce(ReloadHandle) + Send + 'static,
1002    {
1003        self.on_reload_ready = Some(Box::new(callback));
1004        self
1005    }
1006
1007    /// Enable gzip/brotli response compression on MCP responses.
1008    /// `min_size` is the smallest body size (bytes) eligible for
1009    /// compression. Default min size: 1024.
1010    #[must_use]
1011    pub fn enable_compression(mut self, min_size: u16) -> Self {
1012        self.compression_enabled = true;
1013        self.compression_min_size = min_size;
1014        self
1015    }
1016
1017    /// Enable `/admin/*` diagnostic endpoints. Requires
1018    /// [`with_auth`](Self::with_auth) to be set and enabled; otherwise
1019    /// [`validate`](Self::validate) returns an error. `role` is the RBAC
1020    /// role gate (default: `"admin"`).
1021    #[must_use]
1022    pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
1023        self.admin_enabled = true;
1024        self.admin_role = role.into();
1025        self
1026    }
1027
1028    /// Log inbound HTTP request headers at DEBUG level. Sensitive
1029    /// values remain redacted by the logging layer.
1030    #[must_use]
1031    pub fn enable_request_header_logging(mut self) -> Self {
1032        self.log_request_headers = true;
1033        self
1034    }
1035
1036    /// Enable the Prometheus metrics listener on `bind` (e.g.
1037    /// `127.0.0.1:9090`). Requires the `metrics` crate feature.
1038    #[cfg(feature = "metrics")]
1039    #[must_use]
1040    pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
1041        self.metrics_enabled = true;
1042        self.metrics_bind = bind.into();
1043        self
1044    }
1045
1046    /// Validate the configuration and consume `self`, returning a
1047    /// [`Validated<McpServerConfig>`] proof token required by [`serve`]
1048    /// and [`serve_with_listener`]. This is the only way to construct
1049    /// `Validated<McpServerConfig>`, so the type system guarantees
1050    /// validation has run before the server starts.
1051    ///
1052    /// Checks:
1053    ///
1054    /// 1. `admin_enabled` requires `auth` to be configured and enabled.
1055    /// 2. `tls_cert_path` and `tls_key_path` must both be set or both
1056    ///    be unset.
1057    /// 3. `bind_addr` must parse as a [`SocketAddr`].
1058    /// 4. `public_url`, when set, must start with `http://` or `https://`.
1059    /// 5. Each entry in `allowed_origins` must start with `http://` or
1060    ///    `https://`.
1061    /// 6. `max_request_body` must be greater than zero.
1062    /// 7. When the `oauth` feature is enabled and an [`OAuthConfig`] is
1063    ///    present, all OAuth URL fields (`jwks_uri`, `proxy.authorize_url`,
1064    ///    `proxy.token_url`, `proxy.introspection_url`,
1065    ///    `proxy.revocation_url`, `token_exchange.token_url`) must parse
1066    ///    and use the `https` scheme. Set
1067    ///    [`OAuthConfig::allow_http_oauth_urls`] to permit `http://`
1068    ///    targets (strongly discouraged in production - see the field-level
1069    ///    docs for the threat model).
1070    ///
1071    /// [`OAuthConfig`]: crate::oauth::OAuthConfig
1072    /// [`OAuthConfig::allow_http_oauth_urls`]: crate::oauth::OAuthConfig::allow_http_oauth_urls
1073    ///
1074    /// # Errors
1075    ///
1076    /// Returns [`McpxError::Config`] with a human-readable message on
1077    /// the first validation failure.
1078    pub fn validate(self) -> Result<Validated<Self>, McpxError> {
1079        self.check()?;
1080        Ok(Validated(self))
1081    }
1082
1083    /// Validate the burst knobs: every burst must be greater than zero
1084    /// when set, and the two top-level bursts require their base limiter
1085    /// to be configured. The auth bursts
1086    /// (`RateLimitConfig::{burst, pre_auth_burst}`) have no orphan rule:
1087    /// their base rates always resolve (`max_attempts_per_minute` is
1088    /// mandatory; the pre-auth base derives from it when unset).
1089    fn check_burst_knobs(&self) -> Result<(), McpxError> {
1090        if self.tool_rate_limit_burst == Some(0) {
1091            return Err(McpxError::Config(
1092                "tool_rate_limit_burst must be greater than zero".into(),
1093            ));
1094        }
1095        if self.extra_route_rate_limit_burst == Some(0) {
1096            return Err(McpxError::Config(
1097                "extra_route_rate_limit_burst must be greater than zero".into(),
1098            ));
1099        }
1100        if self.tool_rate_limit_burst.is_some() && self.tool_rate_limit.is_none() {
1101            return Err(McpxError::Config(
1102                "tool_rate_limit_burst requires tool_rate_limit to be set".into(),
1103            ));
1104        }
1105        if self.extra_route_rate_limit_burst.is_some() && self.extra_route_rate_limit.is_none() {
1106            return Err(McpxError::Config(
1107                "extra_route_rate_limit_burst requires extra_route_rate_limit to be set".into(),
1108            ));
1109        }
1110        if !self.extra_route_rate_limit_exempt_paths.is_empty()
1111            && self.extra_route_rate_limit.is_none()
1112        {
1113            return Err(McpxError::Config(
1114                "extra_route_rate_limit_exempt_paths requires extra_route_rate_limit to be set"
1115                    .into(),
1116            ));
1117        }
1118        for path in &self.extra_route_rate_limit_exempt_paths {
1119            if path.is_empty() || !path.starts_with('/') {
1120                return Err(McpxError::Config(format!(
1121                    "extra_route_rate_limit_exempt_paths entries must be non-empty and start with '/': {path:?}"
1122                )));
1123            }
1124        }
1125        if let Some(rl) = self.auth.as_ref().and_then(|a| a.rate_limit.as_ref()) {
1126            if rl.burst == Some(0) {
1127                return Err(McpxError::Config(
1128                    "auth rate_limit.burst must be greater than zero".into(),
1129                ));
1130            }
1131            if rl.pre_auth_burst == Some(0) {
1132                return Err(McpxError::Config(
1133                    "auth rate_limit.pre_auth_burst must be greater than zero".into(),
1134                ));
1135            }
1136        }
1137        Ok(())
1138    }
1139
1140    /// Validate the trusted-forwarder knobs: every `trusted_proxies`
1141    /// entry must parse as a CIDR (`ipnet::IpNet`) or a bare IP
1142    /// (normalized to a host network), and `forwarded_header` requires a
1143    /// nonempty proxy list (fail-fast over a silent no-op).
1144    fn check_trusted_forwarder(&self) -> Result<(), McpxError> {
1145        for entry in &self.trusted_proxies {
1146            if parse_proxy_net(entry).is_none() {
1147                return Err(McpxError::Config(format!(
1148                    "trusted_proxies entry {entry:?} is neither a CIDR nor an IP address"
1149                )));
1150            }
1151        }
1152        if self.forwarded_header.is_some() && self.trusted_proxies.is_empty() {
1153            return Err(McpxError::Config(
1154                "forwarded_header requires trusted_proxies to be nonempty".into(),
1155            ));
1156        }
1157        Ok(())
1158    }
1159
1160    /// Run the validation checks without consuming `self`. Used by
1161    /// internal call sites (e.g. tests) that need to inspect a config
1162    /// without taking ownership.
1163    fn check(&self) -> Result<(), McpxError> {
1164        // 1. admin <-> auth dependency. Mirrors the runtime check in
1165        //    `build_app_router`: admin endpoints require an auth state,
1166        //    which is built only when `auth` is `Some` *and* `enabled`.
1167        if self.admin_enabled {
1168            let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
1169            if !auth_enabled {
1170                return Err(McpxError::Config(
1171                    "admin_enabled=true requires auth to be configured and enabled".into(),
1172                ));
1173            }
1174        }
1175
1176        // 2. TLS cert / key must be paired
1177        match (&self.tls_cert_path, &self.tls_key_path) {
1178            (Some(_), None) => {
1179                return Err(McpxError::Config(
1180                    "tls_cert_path is set but tls_key_path is missing".into(),
1181                ));
1182            }
1183            (None, Some(_)) => {
1184                return Err(McpxError::Config(
1185                    "tls_key_path is set but tls_cert_path is missing".into(),
1186                ));
1187            }
1188            _ => {}
1189        }
1190
1191        // 3. bind_addr parses
1192        if self.bind_addr.parse::<SocketAddr>().is_err() {
1193            return Err(McpxError::Config(format!(
1194                "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
1195                self.bind_addr
1196            )));
1197        }
1198
1199        // 4. public_url scheme
1200        if let Some(ref url) = self.public_url
1201            && !(url.starts_with("http://") || url.starts_with("https://"))
1202        {
1203            return Err(McpxError::Config(format!(
1204                "public_url {url:?} must start with http:// or https://"
1205            )));
1206        }
1207
1208        // 5. allowed_origins scheme
1209        for origin in &self.allowed_origins {
1210            if !(origin.starts_with("http://") || origin.starts_with("https://")) {
1211                return Err(McpxError::Config(format!(
1212                    "allowed_origins entry {origin:?} must start with http:// or https://"
1213                )));
1214            }
1215        }
1216
1217        // 6. max_request_body > 0
1218        if self.max_request_body == 0 {
1219            return Err(McpxError::Config(
1220                "max_request_body must be greater than zero".into(),
1221            ));
1222        }
1223
1224        // 6b. extra_route_rate_limit, when set, must be > 0. Unlike the
1225        // legacy tool_rate_limit (which clamps 0 to its default at
1226        // construction), new knobs fail fast on nonsensical values.
1227        if self.extra_route_rate_limit == Some(0) {
1228            return Err(McpxError::Config(
1229                "extra_route_rate_limit must be greater than zero".into(),
1230            ));
1231        }
1232
1233        // 6c. Burst knobs (extracted helper).
1234        self.check_burst_knobs()?;
1235
1236        // 6d. Trusted-forwarder knobs (extracted helper).
1237        self.check_trusted_forwarder()?;
1238
1239        // 7. OAuth URL fields enforce HTTPS (unless `allow_http_oauth_urls`)
1240        #[cfg(feature = "oauth")]
1241        if let Some(auth_cfg) = &self.auth
1242            && let Some(oauth_cfg) = &auth_cfg.oauth
1243        {
1244            oauth_cfg.validate()?;
1245        }
1246
1247        // 8. Security-header overrides parse as valid HTTP header values,
1248        //    and HSTS does not smuggle in a `preload` directive.
1249        validate_security_headers(&self.security_headers)?;
1250
1251        // 9. max_concurrent_requests must be > 0 when set. Zero would
1252        //    deadlock the global concurrency limiter and reject every
1253        //    request. Mirrors the TOML-side check in `src/config.rs`.
1254        if self.max_concurrent_requests == Some(0) {
1255            return Err(McpxError::Config(
1256                "max_concurrent_requests must be greater than zero when set".into(),
1257            ));
1258        }
1259
1260        // 10. Auth rate-limit `max_tracked_keys` must be > 0. A zero cap
1261        //     would force `BoundedKeyedLimiter` to evict on every insert
1262        //     and effectively disable rate limiting.
1263        if let Some(auth_cfg) = &self.auth
1264            && let Some(rl) = &auth_cfg.rate_limit
1265            && rl.max_tracked_keys == 0
1266        {
1267            return Err(McpxError::Config(
1268                "auth.rate_limit.max_tracked_keys must be greater than zero".into(),
1269            ));
1270        }
1271
1272        // 11. tls_handshake_timeout must be > 0. A zero deadline would
1273        //     reap every handshake before it could complete, rejecting
1274        //     all TLS connections. Mirrors the TOML-side check in
1275        //     `src/config.rs`.
1276        if self.tls_handshake_timeout == Duration::ZERO {
1277            return Err(McpxError::Config(
1278                "tls_handshake_timeout must be greater than zero".into(),
1279            ));
1280        }
1281
1282        // 12. max_concurrent_tls_handshakes must be > 0. A zero-permit
1283        //     semaphore would never admit a handshake, deadlocking the
1284        //     TLS accept path. Mirrors the TOML-side check in
1285        //     `src/config.rs`.
1286        if self.max_concurrent_tls_handshakes == 0 {
1287            return Err(McpxError::Config(
1288                "max_concurrent_tls_handshakes must be greater than zero".into(),
1289            ));
1290        }
1291
1292        Ok(())
1293    }
1294}
1295
1296/// Handle for hot-reloading server configuration without restart.
1297///
1298/// Obtained via [`McpServerConfig::on_reload_ready`].
1299/// All swap operations are lock-free and wait-free -- in-flight requests
1300/// finish with the old values while new requests see the update immediately.
1301#[allow(
1302    missing_debug_implementations,
1303    reason = "contains Arc<AuthState> with non-Debug fields"
1304)]
1305pub struct ReloadHandle {
1306    auth: Option<Arc<AuthState>>,
1307    rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
1308    crl_set: Option<Arc<CrlSet>>,
1309}
1310
1311impl ReloadHandle {
1312    /// Atomically replace the API key list used by the auth middleware.
1313    pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
1314        if let Some(ref auth) = self.auth {
1315            auth.reload_keys(keys);
1316        }
1317    }
1318
1319    /// Atomically replace the RBAC policy used by the RBAC middleware.
1320    pub fn reload_rbac(&self, policy: RbacPolicy) {
1321        if let Some(ref rbac) = self.rbac {
1322            rbac.store(Arc::new(policy));
1323            tracing::info!("RBAC policy reloaded");
1324        }
1325    }
1326
1327    /// Force an immediate refresh of all cached mTLS CRLs.
1328    ///
1329    /// # Errors
1330    ///
1331    /// Returns an error if CRL refresh is unavailable or verifier rebuild fails.
1332    pub async fn refresh_crls(&self) -> Result<(), McpxError> {
1333        let Some(ref crl_set) = self.crl_set else {
1334            return Err(McpxError::Config(
1335                "CRL refresh requested but mTLS CRL support is not configured".into(),
1336            ));
1337        };
1338
1339        crl_set.force_refresh().await
1340    }
1341}
1342
1343/// Generic MCP HTTP server.
1344///
1345/// Wraps an axum server with `/healthz` and `/mcp` endpoints.
1346/// When `tls_cert_path` and `tls_key_path` are both set, the server binds
1347/// with TLS (rustls). Optionally supports mTLS client certificate auth.
1348///
1349/// # Errors
1350///
1351/// Returns an error if the TCP listener cannot bind, TLS config is invalid,
1352/// or the server fails.
1353// NOTE: cognitive complexity reduced from 111/25 to 83/25 by
1354// extracting `run_server` (serve-loop tail) and `install_oauth_proxy_routes`.
1355// Remaining flow is a linear router builder: middleware layering, feature-
1356// gated auth/RBAC wiring, and PRM/metrics installation. Further extraction
1357// would require threading many `&mut Router` helpers and hurt readability
1358// of the layer order (which is security-relevant and must stay visible).
1359#[allow(
1360    clippy::too_many_lines,
1361    clippy::cognitive_complexity,
1362    reason = "middleware layer order is security-critical and must remain visible at one glance; extracting `&mut Router` helpers would obscure the auth/RBAC/origin/rate-limit ordering"
1363)]
1364/// Internal bundle of values produced by [`build_app_router`] and
1365/// consumed by [`serve`] / [`serve_with_listener`] when driving the
1366/// HTTP listener.
1367struct AppRunParams {
1368    /// TLS cert/key paths when TLS is configured.
1369    tls_paths: Option<(PathBuf, PathBuf)>,
1370    /// Per-handshake deadline on the TLS accept path.
1371    tls_handshake_timeout: Duration,
1372    /// Cap on concurrently in-flight TLS handshakes.
1373    max_concurrent_tls_handshakes: usize,
1374    /// mTLS configuration when mutual-TLS auth is enabled.
1375    mtls_config: Option<MtlsConfig>,
1376    /// Graceful shutdown drain window.
1377    shutdown_timeout: Duration,
1378    /// Shared auth state used by hot-reload callbacks.
1379    auth_state: Option<Arc<AuthState>>,
1380    /// Hot-reloadable RBAC state used by reload callbacks.
1381    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
1382    /// Optional callback that receives the final [`ReloadHandle`].
1383    on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
1384    /// Server-internal cancellation token. Cancelled by [`run_server`]
1385    /// once the shutdown trigger fires (so rmcp's child token also
1386    /// fires, terminating in-flight MCP sessions).
1387    ct: CancellationToken,
1388    /// `"http"` or `"https"` -- used only for boot-time logging.
1389    scheme: &'static str,
1390    /// Server name -- used only for boot-time logging.
1391    name: String,
1392}
1393
1394/// Build the full application axum [`axum::Router`] (MCP route +
1395/// middleware stack + admin + OAuth + health endpoints + security
1396/// headers + CORS + compression + concurrency limit + origin check)
1397/// and the [`AppRunParams`] needed to drive it.
1398///
1399/// This is the shared core of [`serve`] and [`serve_with_listener`].
1400/// It performs *no* network I/O: callers are responsible for binding
1401/// (or accepting a pre-bound) [`TcpListener`] and invoking
1402/// [`run_server`].
1403#[allow(
1404    clippy::cognitive_complexity,
1405    reason = "router assembly is intrinsically sequential; splitting harms readability"
1406)]
1407#[allow(
1408    deprecated,
1409    reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
1410)]
1411fn build_app_router<H, F>(
1412    mut config: McpServerConfig,
1413    handler_factory: F,
1414) -> anyhow::Result<(axum::Router, AppRunParams)>
1415where
1416    H: ServerHandler + 'static,
1417    F: Fn() -> H + Send + Sync + Clone + 'static,
1418{
1419    let ct = CancellationToken::new();
1420
1421    let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
1422    tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");
1423
1424    let mcp_service = StreamableHttpService::new(
1425        move || Ok(handler_factory()),
1426        {
1427            let mut mgr = LocalSessionManager::default();
1428            mgr.session_config.keep_alive = Some(config.session_idle_timeout);
1429            mgr.into()
1430        },
1431        StreamableHttpServerConfig::default()
1432            .with_allowed_hosts(allowed_hosts)
1433            .with_sse_keep_alive(Some(config.sse_keep_alive))
1434            .with_cancellation_token(ct.child_token()),
1435    );
1436
1437    // Build the MCP route, optionally wrapped with auth and RBAC middleware.
1438    let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);
1439
1440    // Build auth state eagerly when auth is configured so we can wire both
1441    // the auth middleware *and* the optional admin router against the same
1442    // state. The middleware itself is installed further down in layer order.
1443    let auth_state: Option<Arc<AuthState>> = match config.auth {
1444        Some(ref auth_config) if auth_config.enabled => {
1445            let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
1446            let pre_auth_limiter = auth_config
1447                .rate_limit
1448                .as_ref()
1449                .map(crate::auth::build_pre_auth_limiter);
1450
1451            #[cfg(feature = "oauth")]
1452            let jwks_cache = auth_config
1453                .oauth
1454                .as_ref()
1455                .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
1456                .transpose()
1457                .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;
1458
1459            Some(Arc::new(AuthState {
1460                api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
1461                rate_limiter,
1462                pre_auth_limiter,
1463                #[cfg(feature = "oauth")]
1464                jwks_cache,
1465                seen_identities: crate::auth::SeenIdentitySet::new(),
1466                counters: crate::auth::AuthCounters::default(),
1467            }))
1468        }
1469        _ => None,
1470    };
1471
1472    // Build the RBAC policy swap early so the admin router and the later
1473    // RBAC middleware layer share the same hot-reloadable state.
1474    let rbac_swap = Arc::new(ArcSwap::new(
1475        config
1476            .rbac
1477            .clone()
1478            .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
1479    ));
1480
1481    // Optional /admin/* diagnostic routes. Merged BEFORE the
1482    // body-limit/timeout/RBAC/origin/auth layers so all of them apply.
1483    if config.admin_enabled {
1484        let Some(ref auth_state_ref) = auth_state else {
1485            return Err(anyhow::anyhow!(
1486                "admin_enabled=true requires auth to be configured and enabled"
1487            ));
1488        };
1489        let admin_state = crate::admin::AdminState {
1490            started_at: std::time::Instant::now(),
1491            name: config.name.clone(),
1492            version: config.version.clone(),
1493            auth: Some(Arc::clone(auth_state_ref)),
1494            rbac: Arc::clone(&rbac_swap),
1495        };
1496        let admin_cfg = crate::admin::AdminConfig {
1497            role: config.admin_role.clone(),
1498        };
1499        mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
1500        tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
1501    }
1502
1503    // ----- Middleware order (CRITICAL: read carefully) ------------------
1504    //
1505    // axum/tower applies layers **bottom-up** at runtime: the LAST layer
1506    // added is the OUTERMOST (runs first on a request). To achieve a
1507    // request-time flow of:
1508    //
1509    //   body-limit -> timeout -> auth -> rbac -> handler
1510    //
1511    // we add layers in the REVERSE order:
1512    //
1513    //   1. RBAC               (innermost, runs last before handler)
1514    //   2. auth               (parses identity, sets extension for RBAC)
1515    //   3. timeout            (bounds total request time)
1516    //   4. body-limit         (outermost on /mcp; caps payload before
1517    //                          anything else reads/buffers it)
1518    //
1519    // Origin validation is installed on the OUTER router (after the
1520    // /mcp router is merged in), so it also protects /healthz, /readyz,
1521    // /version, and any OAuth proxy endpoints.
1522    //
1523    // Rationale:
1524    // - Body-limit must be outermost on /mcp so RBAC (which reads the
1525    //   JSON-RPC body) cannot be DoS'd by a 100MB payload.
1526    // - Auth must run before RBAC because RBAC consumes
1527    //   `req.extensions().get::<AuthIdentity>()` to enforce per-role
1528    //   policy.
1529    // - Origin runs before auth so we reject cross-origin requests
1530    //   without spending Argon2 cycles on unauthenticated callers.
1531
1532    // [1] RBAC + tool rate-limit layer (innermost; closest to handler).
1533    // Always installed: even when RBAC is disabled, tool rate limiting may
1534    // be active (MCP spec: servers MUST rate limit tool invocations).
1535    {
1536        let tool_limiter: Option<Arc<ToolRateLimiter>> = config
1537            .tool_rate_limit
1538            .map(|per_minute| build_tool_rate_limiter(per_minute, config.tool_rate_limit_burst));
1539
1540        if rbac_swap.load().is_enabled() {
1541            tracing::info!("RBAC enforcement enabled on /mcp");
1542        }
1543        if let Some(limit) = config.tool_rate_limit {
1544            tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
1545        }
1546
1547        let rbac_for_mw = Arc::clone(&rbac_swap);
1548        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1549            let p = rbac_for_mw.load_full();
1550            let tl = tool_limiter.clone();
1551            rbac_middleware(p, tl, req, next)
1552        }));
1553    }
1554
1555    // [2] Auth layer (runs before RBAC so AuthIdentity is in extensions).
1556    if let Some(ref auth_config) = config.auth
1557        && auth_config.enabled
1558    {
1559        let Some(ref state) = auth_state else {
1560            return Err(anyhow::anyhow!("auth state missing despite enabled config"));
1561        };
1562
1563        let methods: Vec<&str> = [
1564            auth_config.mtls.is_some().then_some("mTLS"),
1565            (!auth_config.api_keys.is_empty()).then_some("bearer"),
1566            #[cfg(feature = "oauth")]
1567            auth_config.oauth.is_some().then_some("oauth-jwt"),
1568        ]
1569        .into_iter()
1570        .flatten()
1571        .collect();
1572
1573        tracing::info!(
1574            methods = %methods.join(", "),
1575            api_keys = auth_config.api_keys.len(),
1576            "auth enabled on /mcp"
1577        );
1578
1579        let state_for_mw = Arc::clone(state);
1580        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1581            let s = Arc::clone(&state_for_mw);
1582            auth_middleware(s, req, next)
1583        }));
1584    }
1585
1586    // [3] Request timeout (returns 408 on expiry). Bounds total request
1587    // duration including auth + handler.
1588    mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
1589        axum::http::StatusCode::REQUEST_TIMEOUT,
1590        config.request_timeout,
1591    ));
1592
1593    // [4] Request body size limit (OUTERMOST on /mcp). Prevents OOM /
1594    // DoS from oversized payloads BEFORE any inner layer (auth, RBAC)
1595    // attempts to buffer or parse the body.
1596    mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
1597        config.max_request_body,
1598    ));
1599
1600    // Compute the effective allowed-origins list for the outer
1601    // origin-check layer (installed on the merged router below). When
1602    // `allowed_origins` is empty but `public_url` is set, auto-derive
1603    // the origin from the public URL so MCP clients (e.g. Claude Code)
1604    // that send `Origin: <server-url>` are accepted without explicit
1605    // config.
1606    let mut effective_origins = config.allowed_origins.clone();
1607    if effective_origins.is_empty()
1608        && let Some(ref url) = config.public_url
1609    {
1610        // Origin = scheme + "://" + host (+ ":" + port if non-default).
1611        // Strip any path/query from the public URL. Offsets come from
1612        // `find`, so they are char-boundary-aligned; `get(..)` keeps that
1613        // machine-checked (a violation degrades to an empty slice).
1614        if let Some(scheme_end) = url.find("://") {
1615            let scheme_with_sep = url.get(..scheme_end + 3).unwrap_or_default();
1616            let after_scheme = url.get(scheme_end + 3..).unwrap_or_default();
1617            let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
1618            let host = after_scheme.get(..host_end).unwrap_or_default();
1619            let origin = format!("{scheme_with_sep}{host}");
1620            tracing::info!(
1621                %origin,
1622                "auto-derived allowed origin from public_url"
1623            );
1624            effective_origins.push(origin);
1625        }
1626    }
1627    let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
1628    let cors_origins = Arc::clone(&allowed_origins);
1629    let log_request_headers = config.log_request_headers;
1630
1631    let readyz_route = if let Some(check) = config.readiness_check.take() {
1632        axum::routing::get(move || readyz(Arc::clone(&check)))
1633    } else {
1634        axum::routing::get(healthz)
1635    };
1636
1637    #[allow(unused_mut)] // mut needed when oauth feature adds PRM route
1638    let mut router = axum::Router::new()
1639        .route("/healthz", axum::routing::get(healthz))
1640        .route("/readyz", readyz_route)
1641        .route(
1642            "/version",
1643            axum::routing::get({
1644                // Pre-serialize the version payload once at router-build
1645                // time. The handler then serves a cheap `Arc::clone` of the
1646                // immutable bytes per request, avoiding `serde_json::Value`
1647                // allocation + serialization on every `/version` hit.
1648                let payload_bytes: Arc<[u8]> =
1649                    serialize_version_payload(&config.name, &config.version);
1650                move || {
1651                    let p = Arc::clone(&payload_bytes);
1652                    async move {
1653                        (
1654                            [(axum::http::header::CONTENT_TYPE, "application/json")],
1655                            p.to_vec(),
1656                        )
1657                    }
1658                }
1659            }),
1660        )
1661        .merge(mcp_router);
1662
1663    // Merge application-specific routes (bypass MCP auth/RBAC middleware).
1664    // When configured, wrap them — and only them — in the per-IP rate
1665    // limiter BEFORE merging: axum layers wrap only the routes already
1666    // present on the sub-router, so the limiter can never leak onto
1667    // `/mcp`, health, admin, or OAuth endpoints, while top-level layers
1668    // (origin check, peer-address normalization, ...) still run first.
1669    if let Some(extra) = config.extra_router.take() {
1670        let extra = match config.extra_route_rate_limit {
1671            Some(per_minute) => {
1672                let limiter =
1673                    build_extra_route_rate_limiter(per_minute, config.extra_route_rate_limit_burst);
1674                let exempt: Arc<std::collections::HashSet<String>> = Arc::new(
1675                    config
1676                        .extra_route_rate_limit_exempt_paths
1677                        .iter()
1678                        .cloned()
1679                        .collect(),
1680                );
1681                tracing::info!(
1682                    per_minute,
1683                    exempt_paths = exempt.len(),
1684                    "extra-route per-IP rate limit enabled"
1685                );
1686                extra.layer(axum::middleware::from_fn(move |req, next| {
1687                    let l = Arc::clone(&limiter);
1688                    let e = Arc::clone(&exempt);
1689                    extra_route_rate_limit_middleware(l, e, req, next)
1690                }))
1691            }
1692            None => extra,
1693        };
1694        router = router.merge(extra);
1695    }
1696
1697    // RFC 9728: Protected Resource Metadata endpoint.
1698    // When OAuth is configured, serve full metadata with authorization_servers.
1699    // Otherwise, serve a minimal document with just the resource URL and no
1700    // authorization_servers -- this tells MCP clients (e.g. Claude Code SDK)
1701    // that the server exists but does NOT require OAuth authentication,
1702    // preventing them from gating the connection behind a broken auth flow.
1703    let server_url = if let Some(ref url) = config.public_url {
1704        url.trim_end_matches('/').to_owned()
1705    } else {
1706        let prm_scheme = if config.tls_cert_path.is_some() {
1707            "https"
1708        } else {
1709            "http"
1710        };
1711        format!("{prm_scheme}://{}", config.bind_addr)
1712    };
1713    let resource_url = format!("{server_url}/mcp");
1714
1715    #[cfg(feature = "oauth")]
1716    let prm_metadata = if let Some(ref auth_config) = config.auth
1717        && let Some(ref oauth_config) = auth_config.oauth
1718    {
1719        crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
1720    } else {
1721        serde_json::json!({ "resource": resource_url })
1722    };
1723    #[cfg(not(feature = "oauth"))]
1724    let prm_metadata = serde_json::json!({ "resource": resource_url });
1725
1726    router = router.route(
1727        "/.well-known/oauth-protected-resource",
1728        axum::routing::get(move || {
1729            let m = prm_metadata.clone();
1730            async move { axum::Json(m) }
1731        }),
1732    );
1733
1734    // OAuth 2.1 proxy endpoints: when an OAuth proxy is configured, expose
1735    // /authorize, /token, /register, and authorization server metadata so
1736    // MCP clients can perform Authorization Code + PKCE against the upstream
1737    // IdP (e.g. Keycloak) transparently.
1738    #[cfg(feature = "oauth")]
1739    if let Some(ref auth_config) = config.auth
1740        && let Some(ref oauth_config) = auth_config.oauth
1741        && oauth_config.proxy.is_some()
1742    {
1743        router =
1744            install_oauth_proxy_routes(router, &server_url, oauth_config, auth_state.as_ref())?;
1745    }
1746
1747    // OWASP security response headers (applied to all responses).
1748    // HSTS is conditional on TLS being configured.
1749    let is_tls = config.tls_cert_path.is_some();
1750    let security_headers_cfg = Arc::new(config.security_headers.clone());
1751    router = router.layer(axum::middleware::from_fn(move |req, next| {
1752        let cfg = Arc::clone(&security_headers_cfg);
1753        security_headers_middleware(is_tls, cfg, req, next)
1754    }));
1755
1756    // CORS preflight layer (required for browser-based MCP clients).
1757    // Uses the same effective origins as the origin check middleware
1758    // (including auto-derived origin from public_url).
1759    if !cors_origins.is_empty() {
1760        let cors = tower_http::cors::CorsLayer::new()
1761            .allow_origin(
1762                cors_origins
1763                    .iter()
1764                    .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
1765                    .collect::<Vec<_>>(),
1766            )
1767            .allow_methods([
1768                axum::http::Method::GET,
1769                axum::http::Method::POST,
1770                axum::http::Method::OPTIONS,
1771            ])
1772            .allow_headers([
1773                axum::http::header::CONTENT_TYPE,
1774                axum::http::header::AUTHORIZATION,
1775            ]);
1776        router = router.layer(cors);
1777    }
1778
1779    // Optional response compression (gzip + brotli). Skips small bodies
1780    // to avoid overhead. Applied after CORS so preflight responses remain
1781    // uncompressed.
1782    if config.compression_enabled {
1783        use tower_http::compression::Predicate as _;
1784        let predicate = tower_http::compression::DefaultPredicate::new().and(
1785            tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
1786        );
1787        router = router.layer(
1788            tower_http::compression::CompressionLayer::new()
1789                .gzip(true)
1790                .br(true)
1791                .compress_when(predicate),
1792        );
1793        tracing::info!(
1794            min_size = config.compression_min_size,
1795            "response compression enabled (gzip, br)"
1796        );
1797    }
1798
1799    // Optional global concurrency cap. `load_shed` converts the
1800    // `ConcurrencyLimit` back-pressure error into 503 instead of hanging.
1801    if let Some(max) = config.max_concurrent_requests {
1802        let overload_handler = tower::ServiceBuilder::new()
1803            .layer(axum::error_handling::HandleErrorLayer::new(
1804                |_err: tower::BoxError| async {
1805                    (
1806                        axum::http::StatusCode::SERVICE_UNAVAILABLE,
1807                        axum::Json(serde_json::json!({
1808                            "error": "overloaded",
1809                            "error_description": "server is at capacity, retry later"
1810                        })),
1811                    )
1812                },
1813            ))
1814            .layer(tower::load_shed::LoadShedLayer::new())
1815            .layer(tower::limit::ConcurrencyLimitLayer::new(max));
1816        router = router.layer(overload_handler);
1817        tracing::info!(max, "global concurrency limit enabled");
1818    }
1819
1820    // JSON fallback for unmatched routes. Without this, axum returns
1821    // an empty-body 404 that breaks MCP clients (e.g. Claude Code SDK)
1822    // when they probe OAuth endpoints like /authorize or /token.
1823    router = router.fallback(|| async {
1824        (
1825            axum::http::StatusCode::NOT_FOUND,
1826            axum::Json(serde_json::json!({
1827                "error": "not_found",
1828                "error_description": "The requested endpoint does not exist"
1829            })),
1830        )
1831    });
1832
1833    // Prometheus metrics: recording middleware + separate listener.
1834    #[cfg(feature = "metrics")]
1835    if config.metrics_enabled {
1836        let metrics = Arc::new(
1837            crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
1838        );
1839        let m = Arc::clone(&metrics);
1840        router = router.layer(axum::middleware::from_fn(
1841            move |req: Request<Body>, next: Next| {
1842                let m = Arc::clone(&m);
1843                metrics_middleware(m, req, next)
1844            },
1845        ));
1846        let metrics_bind = config.metrics_bind.clone();
1847        let metrics_shutdown = ct.clone();
1848        tokio::spawn(async move {
1849            if let Err(e) =
1850                crate::metrics::serve_metrics(metrics_bind, metrics, metrics_shutdown).await
1851            {
1852                tracing::error!("metrics listener failed: {e}");
1853            }
1854        });
1855    }
1856
1857    // Peer-address normalization. Mirrors the TLS branch's peer address
1858    // into `ConnectInfo<SocketAddr>` and exposes the framework-owned
1859    // `PeerAddr` extension on both listener branches, so ALL routes on
1860    // the merged router (`/mcp`, `/healthz`, OAuth proxy endpoints,
1861    // admin endpoints, extra_router, ...) and all inner middleware see a
1862    // uniform peer-address contract regardless of TLS. Installed just
1863    // inside the origin check, which stays outermost by design.
1864    let forward_resolver: Option<Arc<ForwardResolver>> = if config.trusted_proxies.is_empty() {
1865        None
1866    } else {
1867        // Entries are guaranteed parseable by `check_trusted_forwarder`;
1868        // filter_map is defensive only.
1869        Some(Arc::new(ForwardResolver {
1870            trusted: config
1871                .trusted_proxies
1872                .iter()
1873                .filter_map(|entry| parse_proxy_net(entry))
1874                .collect(),
1875            mode: config
1876                .forwarded_header
1877                .unwrap_or(ForwardedHeaderMode::XForwardedFor),
1878        }))
1879    };
1880    if forward_resolver.is_some() {
1881        tracing::info!(
1882            proxies = config.trusted_proxies.len(),
1883            "trusted-forwarder mode enabled: limiters key by resolved client IP"
1884        );
1885    }
1886    router = router.layer(axum::middleware::from_fn(move |req, next| {
1887        let r = forward_resolver.clone();
1888        normalize_peer_addr_middleware(r, req, next)
1889    }));
1890
1891    // Origin validation layer (MCP spec: servers MUST validate the
1892    // Origin header to prevent DNS rebinding attacks). Installed as the
1893    // OUTERMOST layer on the OUTER router so it protects ALL routes
1894    // (`/mcp`, `/healthz`, `/readyz`, `/version`, OAuth proxy endpoints,
1895    // admin endpoints, extra_router, etc.) and runs BEFORE auth so we
1896    // reject cross-origin attackers without spending Argon2 cycles.
1897    //
1898    // Origin-less requests (e.g. server-to-server probes, curl, native
1899    // MCP clients) are permitted; only requests with an Origin header
1900    // that does not match `effective_origins` are rejected.
1901    router = router.layer(axum::middleware::from_fn(move |req, next| {
1902        let origins = Arc::clone(&allowed_origins);
1903        origin_check_middleware(origins, log_request_headers, req, next)
1904    }));
1905
1906    let scheme = if config.tls_cert_path.is_some() {
1907        "https"
1908    } else {
1909        "http"
1910    };
1911
1912    let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
1913        (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
1914        _ => None,
1915    };
1916    let tls_handshake_timeout = config.tls_handshake_timeout;
1917    let max_concurrent_tls_handshakes = config.max_concurrent_tls_handshakes;
1918    let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();
1919
1920    Ok((
1921        router,
1922        AppRunParams {
1923            tls_paths,
1924            tls_handshake_timeout,
1925            max_concurrent_tls_handshakes,
1926            mtls_config,
1927            shutdown_timeout: config.shutdown_timeout,
1928            auth_state,
1929            rbac_swap,
1930            on_reload_ready: config.on_reload_ready.take(),
1931            ct,
1932            scheme,
1933            name: config.name.clone(),
1934        },
1935    ))
1936}
1937
1938/// Run the MCP HTTP server, binding to `config.bind_addr` and serving
1939/// until an OS shutdown signal (Ctrl-C / SIGTERM) is received.
1940///
1941/// This is the standard entry point for production deployments. For
1942/// deterministic shutdown control (e.g. integration tests), see
1943/// [`serve_with_listener`].
1944///
1945/// The configuration must be validated first via
1946/// [`McpServerConfig::validate`], which returns a [`Validated`] proof
1947/// token. This typestate guarantees, at compile time, that the server
1948/// never starts with an invalid configuration.
1949///
1950/// # Errors
1951///
1952/// Returns [`McpxError::Startup`] if binding to `config.bind_addr`
1953/// fails, or if the underlying axum server returns an error.
1954pub async fn serve<H, F>(
1955    config: Validated<McpServerConfig>,
1956    handler_factory: F,
1957) -> Result<(), McpxError>
1958where
1959    H: ServerHandler + 'static,
1960    F: Fn() -> H + Send + Sync + Clone + 'static,
1961{
1962    let config = config.into_inner();
1963    #[allow(
1964        deprecated,
1965        reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1966    )]
1967    let bind_addr = config.bind_addr.clone();
1968    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1969
1970    let listener = TcpListener::bind(&bind_addr)
1971        .await
1972        .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1973    log_listening(&params.name, params.scheme, &bind_addr);
1974
1975    run_server(
1976        router,
1977        listener,
1978        params.tls_paths,
1979        params.tls_handshake_timeout,
1980        params.max_concurrent_tls_handshakes,
1981        params.mtls_config,
1982        params.shutdown_timeout,
1983        params.auth_state,
1984        params.rbac_swap,
1985        params.on_reload_ready,
1986        params.ct,
1987    )
1988    .await
1989    .map_err(anyhow_to_startup)
1990}
1991
1992/// Run the MCP HTTP server on a pre-bound [`TcpListener`], with optional
1993/// readiness signalling and external shutdown control.
1994///
1995/// This variant is intended for **deterministic integration tests** and
1996/// for embedders that need to bind the listening socket themselves
1997/// (e.g. systemd socket activation). Compared to [`serve`]:
1998///
1999/// * The caller passes a `TcpListener` that is already bound. This
2000///   eliminates the bind race in tests that previously required
2001///   poll-the-`/healthz`-loop start-up detection.
2002/// * `ready_tx`, when `Some`, receives the socket's
2003///   [`SocketAddr`] *after* the router is built and immediately before
2004///   the server starts accepting connections. Tests can `await` the
2005///   matching `oneshot::Receiver` to know exactly when it is safe to
2006///   issue requests.
2007/// * `shutdown`, when `Some`, gives the caller a
2008///   [`CancellationToken`] that triggers the same graceful-shutdown
2009///   path as a real OS signal. This avoids cross-platform issues with
2010///   sending real `SIGTERM` from tests on Windows.
2011///
2012/// All three optional parameters degrade gracefully: if `ready_tx` is
2013/// `None`, no signal is sent; if `shutdown` is `None`, the server only
2014/// stops on an OS signal (just like [`serve`]).
2015///
2016/// # Errors
2017///
2018/// Returns [`McpxError::Startup`] if router construction fails, if reading
2019/// the listener's `local_addr()` fails, or if the underlying axum
2020/// server returns an error.
2021pub async fn serve_with_listener<H, F>(
2022    listener: TcpListener,
2023    config: Validated<McpServerConfig>,
2024    handler_factory: F,
2025    ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
2026    shutdown: Option<CancellationToken>,
2027) -> Result<(), McpxError>
2028where
2029    H: ServerHandler + 'static,
2030    F: Fn() -> H + Send + Sync + Clone + 'static,
2031{
2032    let config = config.into_inner();
2033    let local_addr = listener
2034        .local_addr()
2035        .map_err(|e| io_to_startup("listener.local_addr", e))?;
2036    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
2037
2038    log_listening(&params.name, params.scheme, &local_addr.to_string());
2039
2040    // Forward external shutdown into the server-internal cancellation
2041    // token so `run_server`'s shutdown trigger picks it up alongside
2042    // any real OS signal.
2043    if let Some(external) = shutdown {
2044        let internal = params.ct.clone();
2045        tokio::spawn(async move {
2046            external.cancelled().await;
2047            internal.cancel();
2048        });
2049    }
2050
2051    // Signal readiness *after* the router is fully built and external
2052    // shutdown is wired, but *before* run_server takes ownership of
2053    // the listener. The receiver can immediately issue requests.
2054    if let Some(tx) = ready_tx {
2055        // Receiver may have been dropped (test gave up). That's fine.
2056        let _ = tx.send(local_addr);
2057    }
2058
2059    run_server(
2060        router,
2061        listener,
2062        params.tls_paths,
2063        params.tls_handshake_timeout,
2064        params.max_concurrent_tls_handshakes,
2065        params.mtls_config,
2066        params.shutdown_timeout,
2067        params.auth_state,
2068        params.rbac_swap,
2069        params.on_reload_ready,
2070        params.ct,
2071    )
2072    .await
2073    .map_err(anyhow_to_startup)
2074}
2075
2076/// Emit the standard "listening on …" log lines used by both
2077/// [`serve`] and [`serve_with_listener`].
2078#[allow(
2079    clippy::cognitive_complexity,
2080    reason = "tracing::info! macro expansions inflate the score; logic is trivial"
2081)]
2082fn log_listening(name: &str, scheme: &str, addr: &str) {
2083    tracing::info!("{name} listening on {addr}");
2084    tracing::info!("  MCP endpoint: {scheme}://{addr}/mcp");
2085    tracing::info!("  Health check: {scheme}://{addr}/healthz");
2086    tracing::info!("  Readiness:   {scheme}://{addr}/readyz");
2087}
2088
2089/// Drive the chosen axum server variant (TLS or plain) with a graceful
2090/// shutdown window. Consumes the router and listener.
2091///
2092/// # Shutdown semantics
2093///
2094/// A single shutdown trigger (the FIRST of: OS signal via
2095/// `shutdown_signal()`, or external cancellation of `ct`) starts BOTH:
2096///
2097/// 1. axum's `.with_graceful_shutdown(...)` future, which stops
2098///    accepting new connections and waits for in-flight requests to
2099///    drain;
2100/// 2. a `tokio::time::sleep(shutdown_timeout)` race that forces exit if
2101///    drainage exceeds `shutdown_timeout`.
2102///
2103/// Previously this function awaited `shutdown_signal()` independently
2104/// in BOTH branches of a `tokio::select!`. Because `shutdown_signal`
2105/// resolves once per future and consumes one signal, the force-exit
2106/// timer was tied to a SECOND signal (a second SIGTERM the operator
2107/// would never send). Under a single SIGTERM the graceful drain could
2108/// hang indefinitely. The current implementation derives both branches
2109/// from a single shared trigger so the timeout race is anchored to the
2110/// FIRST (and only) signal.
2111#[allow(
2112    clippy::too_many_arguments,
2113    clippy::cognitive_complexity,
2114    reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
2115)]
2116async fn run_server(
2117    router: axum::Router,
2118    listener: TcpListener,
2119    tls_paths: Option<(PathBuf, PathBuf)>,
2120    tls_handshake_timeout: Duration,
2121    max_concurrent_tls_handshakes: usize,
2122    mtls_config: Option<MtlsConfig>,
2123    shutdown_timeout: Duration,
2124    auth_state: Option<Arc<AuthState>>,
2125    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
2126    mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
2127    ct: CancellationToken,
2128) -> anyhow::Result<()> {
2129    // `shutdown_trigger` fires when the FIRST source resolves: either
2130    // an OS signal (Ctrl-C / SIGTERM) or external cancellation of `ct`
2131    // (which the test harness uses for deterministic shutdown).
2132    let shutdown_trigger = CancellationToken::new();
2133    {
2134        let trigger = shutdown_trigger.clone();
2135        let parent = ct.clone();
2136        tokio::spawn(async move {
2137            // cancel-safe: both arms (signal future, CancellationToken::cancelled)
2138            // are cancel-safe; the losing arm holds no state.
2139            tokio::select! {
2140                () = shutdown_signal() => {}
2141                () = parent.cancelled() => {}
2142            }
2143            trigger.cancel();
2144        });
2145    }
2146
2147    let graceful = {
2148        let trigger = shutdown_trigger.clone();
2149        let ct = ct.clone();
2150        async move {
2151            trigger.cancelled().await;
2152            tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
2153            ct.cancel();
2154        }
2155    };
2156
2157    let force_exit_timer = {
2158        let trigger = shutdown_trigger.clone();
2159        async move {
2160            trigger.cancelled().await;
2161            tokio::time::sleep(shutdown_timeout).await;
2162        }
2163    };
2164
2165    if let Some((cert_path, key_path)) = tls_paths {
2166        let crl_set = if let Some(mtls) = mtls_config.as_ref()
2167            && mtls.crl_enabled
2168        {
2169            let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
2170            let (crl_set, discover_rx) =
2171                mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
2172                    .await
2173                    .map_err(|error| anyhow::anyhow!(error.to_string()))?;
2174            tokio::spawn(mtls_revocation::run_crl_refresher(
2175                Arc::clone(&crl_set),
2176                discover_rx,
2177                ct.clone(),
2178            ));
2179            Some(crl_set)
2180        } else {
2181            None
2182        };
2183
2184        if let Some(cb) = on_reload_ready.take() {
2185            cb(ReloadHandle {
2186                auth: auth_state.clone(),
2187                rbac: Some(Arc::clone(&rbac_swap)),
2188                crl_set: crl_set.clone(),
2189            });
2190        }
2191
2192        let tls_listener = TlsListener::new(
2193            listener,
2194            &cert_path,
2195            &key_path,
2196            mtls_config.as_ref(),
2197            crl_set,
2198            tls_handshake_timeout,
2199            max_concurrent_tls_handshakes,
2200        )?;
2201        let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
2202        // cancel-safe: dropping the serve future on force-exit is intentional
2203        // forced-shutdown semantics; force_exit_timer is a Sleep chain.
2204        tokio::select! {
2205            result = axum::serve(tls_listener, make_svc)
2206                .with_graceful_shutdown(graceful) => { result?; }
2207            () = force_exit_timer => {
2208                tracing::warn!("shutdown timeout exceeded, forcing exit");
2209            }
2210        }
2211    } else {
2212        if let Some(cb) = on_reload_ready.take() {
2213            cb(ReloadHandle {
2214                auth: auth_state,
2215                rbac: Some(rbac_swap),
2216                crl_set: None,
2217            });
2218        }
2219
2220        let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
2221        // cancel-safe: dropping the serve future on force-exit is intentional
2222        // forced-shutdown semantics; force_exit_timer is a Sleep chain.
2223        tokio::select! {
2224            result = axum::serve(listener, make_svc)
2225                .with_graceful_shutdown(graceful) => { result?; }
2226            () = force_exit_timer => {
2227                tracing::warn!("shutdown timeout exceeded, forcing exit");
2228            }
2229        }
2230    }
2231
2232    Ok(())
2233}
2234
/// Install the OAuth 2.1 proxy endpoints (`/authorize`, `/token`,
/// `/register`, and authorization server metadata) on `router`.
///
/// When `proxy.expose_admin_endpoints` is set and at least one of
/// `proxy.introspection_url` / `proxy.revocation_url` is configured, the
/// matching `/introspect` / `/revoke` admin routes are merged in as well
/// (built by [`build_oauth_admin_router`]).
///
/// If `oauth_config.proxy` is `None`, `router` is returned unchanged.
///
/// # Errors
///
/// Returns [`McpxError::Startup`] if the shared
/// [`crate::oauth::OauthHttpClient`] cannot be initialized, or if the admin
/// routes require authentication (`require_auth_on_admin_endpoints`) but
/// `auth_state` is `None`.
#[cfg(feature = "oauth")]
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let Some(ref proxy) = oauth_config.proxy else {
        return Ok(router);
    };

    // Single shared HTTP client for all proxy endpoints. Cloning is
    // cheap (refcounted) and shares the underlying connection pool.
    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;

    let asm = crate::oauth::authorization_server_metadata(server_url, oauth_config);
    let router = router.route(
        "/.well-known/oauth-authorization-server",
        axum::routing::get(move || {
            let m = asm.clone();
            async move { axum::Json(m) }
        }),
    );

    let proxy_authorize = proxy.clone();
    let router = router.route(
        "/authorize",
        axum::routing::get(
            move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
                let p = proxy_authorize.clone();
                async move { crate::oauth::handle_authorize(&p, &query.unwrap_or_default()) }
            },
        ),
    );

    let proxy_token = proxy.clone();
    let token_http = http.clone();
    let router = router.route(
        "/token",
        axum::routing::post(move |body: String| {
            let p = proxy_token.clone();
            let h = token_http.clone();
            async move { crate::oauth::handle_token(&h, &p, &body).await }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    let proxy_register = proxy.clone();
    let router = router.route(
        "/register",
        axum::routing::post(move |axum::Json(body): axum::Json<serde_json::Value>| {
            let p = proxy_register;
            async move { axum::Json(crate::oauth::handle_register(&p, &body)) }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    let admin_routes_enabled = proxy.expose_admin_endpoints
        && (proxy.introspection_url.is_some() || proxy.revocation_url.is_some());
    if admin_routes_enabled
        && !proxy.require_auth_on_admin_endpoints
        && proxy.allow_unauthenticated_admin_endpoints
    {
        // M3 escape-hatch in effect: validate() let this through because
        // the operator explicitly opted in. Surface it loudly at startup
        // so the choice is auditable in logs. Only warn when admin routes
        // are actually mounted; with no introspection/revocation URL there
        // is nothing unauthenticated to warn about.
        tracing::warn!(
            "OAuth introspect/revoke endpoints are unauthenticated by explicit \
             allow_unauthenticated_admin_endpoints opt-out; ensure an \
             authenticated reverse proxy fronts these routes"
        );
    }

    let admin_router = if admin_routes_enabled {
        build_oauth_admin_router(proxy, http, auth_state)?
    } else {
        axum::Router::new()
    };

    let router = router.merge(admin_router);

    tracing::info!(
        introspect = proxy.expose_admin_endpoints && proxy.introspection_url.is_some(),
        revoke = proxy.expose_admin_endpoints && proxy.revocation_url.is_some(),
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
2335
/// Assemble the optional OAuth admin sub-router (`/introspect`, `/revoke`).
///
/// A route is added only when its upstream URL is configured on `proxy`.
/// Every admin response passes through
/// [`oauth_token_cache_headers_middleware`] (RFC 6749 §5.1 / RFC 6750 §5.4
/// cache headers). When `proxy.require_auth_on_admin_endpoints` is set, the
/// whole sub-router is additionally wrapped in the auth middleware.
///
/// # Errors
///
/// Returns [`McpxError::Startup`] when authentication is required for the
/// admin routes but no `auth_state` was provided.
#[cfg(feature = "oauth")]
fn build_oauth_admin_router(
    proxy: &crate::oauth::OAuthProxyConfig,
    http: crate::oauth::OauthHttpClient,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let mut routes = axum::Router::new();

    if proxy.introspection_url.is_some() {
        let introspect_cfg = proxy.clone();
        let introspect_client = http.clone();
        let introspect = move |body: String| {
            let cfg = introspect_cfg.clone();
            let client = introspect_client.clone();
            async move { crate::oauth::handle_introspect(&client, &cfg, &body).await }
        };
        routes = routes.route("/introspect", axum::routing::post(introspect));
    }

    if proxy.revocation_url.is_some() {
        let revoke_cfg = proxy.clone();
        let revoke_client = http;
        let revoke = move |body: String| {
            let cfg = revoke_cfg.clone();
            let client = revoke_client.clone();
            async move { crate::oauth::handle_revoke(&client, &cfg, &body).await }
        };
        routes = routes.route("/revoke", axum::routing::post(revoke));
    }

    let routes = routes.layer(axum::middleware::from_fn(
        oauth_token_cache_headers_middleware,
    ));

    if !proxy.require_auth_on_admin_endpoints {
        return Ok(routes);
    }

    let state = auth_state.ok_or_else(|| {
        McpxError::Startup("oauth proxy admin endpoints require auth state".into())
    })?;
    let state = Arc::clone(state);
    Ok(routes.layer(axum::middleware::from_fn(move |req, next| {
        auth_middleware(Arc::clone(&state), req, next)
    })))
}
2394
2395/// Build the host allow-list for rmcp's DNS rebinding protection.
2396///
2397/// Includes loopback hosts by default, then augments with host/authority
2398/// derived from `public_url` and the server bind address.
2399fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
2400    let mut hosts = vec![
2401        "localhost".to_owned(),
2402        "127.0.0.1".to_owned(),
2403        "::1".to_owned(),
2404    ];
2405
2406    if let Some(url) = public_url
2407        && let Ok(uri) = url.parse::<axum::http::Uri>()
2408        && let Some(authority) = uri.authority()
2409    {
2410        let host = authority.host().to_owned();
2411        if !hosts.iter().any(|h| h == &host) {
2412            hosts.push(host);
2413        }
2414
2415        let authority = authority.as_str().to_owned();
2416        if !hosts.iter().any(|h| h == &authority) {
2417            hosts.push(authority);
2418        }
2419    }
2420
2421    if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
2422        && let Some(authority) = uri.authority()
2423    {
2424        let host = authority.host().to_owned();
2425        if !hosts.iter().any(|h| h == &host) {
2426            hosts.push(host);
2427        }
2428
2429        let authority = authority.as_str().to_owned();
2430        if !hosts.iter().any(|h| h == &authority) {
2431            hosts.push(authority);
2432        }
2433    }
2434
2435    hosts
2436}
2437
2438// - TLS support -
2439
2440/// Implement axum's `Connected` trait for `TlsConnInfo` so that
2441/// `ConnectInfo<TlsConnInfo>` is available in middleware when serving
2442/// over our custom `TlsListener`.
2443///
2444/// The identity is read directly from the wrapping
2445/// [`AuthenticatedTlsStream`], which guarantees one-to-one correspondence
2446/// between the TLS connection and its mTLS identity. This eliminates the
2447/// previous shared-map approach which was vulnerable to ephemeral-port
2448/// reuse races (an unauthenticated reconnection from the same `(IP, port)`
2449/// pair could alias a stale entry).
2450impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
2451    for TlsConnInfo
2452{
2453    fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
2454        let addr = *target.remote_addr();
2455        let identity = target.io().identity().cloned();
2456        Self::new(addr, identity)
2457    }
2458}
2459
/// Deadline applied to each TLS handshake on the accept path unless
/// overridden. An idle or slow-loris client therefore cannot hold a
/// handshake worker task, or the concurrency permit it owns, forever.
///
/// Overridable since 1.9.0 through
/// [`McpServerConfig::with_tls_handshake_timeout`].
const DEFAULT_TLS_HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(10_000);

/// Cap on concurrently in-flight TLS handshakes unless overridden. When the
/// cap is reached, the acceptor task stops pulling connections from the
/// kernel backlog. This applies backpressure instead of accepting sockets
/// in user space only to drop them.
///
/// Overridable since 1.9.0 through
/// [`McpServerConfig::with_max_concurrent_tls_handshakes`].
const DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES: usize = 256;

/// Depth of the queue that hands completed handshakes from the acceptor
/// task to the `accept()` loop of `axum::serve`. When the queue is full,
/// handshake workers wait on `send`. A slow accept loop therefore slows
/// down handshakes instead of letting finished connections pile up
/// without bound.
const TLS_ACCEPT_CHANNEL_CAPACITY: usize = 32;
2482
/// A TLS-wrapping listener that implements axum's `Listener` trait.
///
/// TCP accepts and TLS handshakes run on a dedicated background task: each
/// accepted connection's handshake is spawned onto its own worker task,
/// bounded by a configurable concurrent-handshake cap (default
/// [`DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES`]) and a per-handshake timeout
/// (default [`DEFAULT_TLS_HANDSHAKE_TIMEOUT`]). A slow or idle client
/// therefore cannot stall other connections behind a serialized inline
/// handshake.
///
/// When mTLS is configured, client certificates are checked at handshake
/// time. Without CRL checking they are verified against the configured CA.
/// With CRL checking enabled, the CRL-aware `DynamicClientCertVerifier`
/// does the verification. The client identity is also extracted at
/// handshake time and bound to the connection itself via the returned
/// [`AuthenticatedTlsStream`], so an unrelated connection cannot observe
/// it.
///
/// Dropping the listener aborts the acceptor task (see the `Drop` impl for
/// `TlsListener`).
struct TlsListener {
    /// Bound address, captured eagerly before the `TcpListener` moves into
    /// the acceptor task; returned as-is by `Listener::local_addr`.
    local_addr: SocketAddr,
    /// Completed handshakes produced by the acceptor task's workers, each
    /// paired with the peer address. The channel is created with capacity
    /// `TLS_ACCEPT_CHANNEL_CAPACITY` in `TlsListener::new`.
    rx: mpsc::Receiver<(AuthenticatedTlsStream, SocketAddr)>,
    /// Background task driving TCP accepts and concurrent TLS handshakes.
    /// Aborted on drop so the listener releases its port deterministically.
    acceptor_task: tokio::task::JoinHandle<()>,
}
2508
2509impl TlsListener {
2510    fn new(
2511        inner: TcpListener,
2512        cert_path: &Path,
2513        key_path: &Path,
2514        mtls_config: Option<&MtlsConfig>,
2515        crl_set: Option<Arc<CrlSet>>,
2516        handshake_timeout: Duration,
2517        max_concurrent_handshakes: usize,
2518    ) -> anyhow::Result<Self> {
2519        // Install the ring crypto provider (ok to call multiple times).
2520        rustls::crypto::ring::default_provider()
2521            .install_default()
2522            .ok();
2523
2524        let certs = load_certs(cert_path)?;
2525        let key = load_key(key_path)?;
2526
2527        let mtls_default_role;
2528
2529        let tls_config = if let Some(mtls) = mtls_config {
2530            mtls_default_role = mtls.default_role.clone();
2531            let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
2532            {
2533                let Some(crl_set) = crl_set else {
2534                    return Err(anyhow::anyhow!(
2535                        "mTLS CRL verifier requested but CRL state was not initialized"
2536                    ));
2537                };
2538                Arc::new(DynamicClientCertVerifier::new(crl_set))
2539            } else {
2540                let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
2541                if mtls.required {
2542                    rustls::server::WebPkiClientVerifier::builder(root_store)
2543                        .build()
2544                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
2545                } else {
2546                    rustls::server::WebPkiClientVerifier::builder(root_store)
2547                        .allow_unauthenticated()
2548                        .build()
2549                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
2550                }
2551            };
2552
2553            tracing::info!(
2554                ca = %mtls.ca_cert_path.display(),
2555                required = mtls.required,
2556                crl_enabled = mtls.crl_enabled,
2557                "mTLS client auth configured"
2558            );
2559
2560            rustls::ServerConfig::builder_with_protocol_versions(&[
2561                &rustls::version::TLS12,
2562                &rustls::version::TLS13,
2563            ])
2564            .with_client_cert_verifier(verifier)
2565            .with_single_cert(certs, key)?
2566        } else {
2567            mtls_default_role = "viewer".to_owned();
2568            rustls::ServerConfig::builder_with_protocol_versions(&[
2569                &rustls::version::TLS12,
2570                &rustls::version::TLS13,
2571            ])
2572            .with_no_client_auth()
2573            .with_single_cert(certs, key)?
2574        };
2575
2576        let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
2577        tracing::info!(
2578            "TLS enabled (cert: {}, key: {})",
2579            cert_path.display(),
2580            key_path.display()
2581        );
2582        let local_addr = inner.local_addr()?;
2583        let (tx, rx) = mpsc::channel(TLS_ACCEPT_CHANNEL_CAPACITY);
2584        let acceptor_task = tokio::spawn(run_tls_acceptor(
2585            inner,
2586            acceptor,
2587            mtls_default_role,
2588            tx,
2589            handshake_timeout,
2590            max_concurrent_handshakes,
2591        ));
2592        Ok(Self {
2593            local_addr,
2594            rx,
2595            acceptor_task,
2596        })
2597    }
2598
2599    /// Extract the mTLS client cert identity from a completed TLS handshake.
2600    /// Returns `None` if no client certificate was presented or if the
2601    /// certificate could not be parsed into an [`AuthIdentity`].
2602    fn extract_handshake_identity(
2603        tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
2604        default_role: &str,
2605        addr: SocketAddr,
2606    ) -> Option<AuthIdentity> {
2607        let (_, server_conn) = tls_stream.get_ref();
2608        let cert_der = server_conn.peer_certificates()?.first()?;
2609        let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
2610        tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
2611        Some(id)
2612    }
2613}
2614
2615/// Drive TCP accepts and concurrent TLS handshakes for [`TlsListener`].
2616///
2617/// Each accepted connection's handshake runs on its own worker task under
2618/// a permit from a `max_concurrent_handshakes`-sized semaphore and a
2619/// `handshake_timeout` deadline. Completed handshakes are pushed to `tx`;
2620/// failures and timeouts are logged at DEBUG and the connection dropped.
2621/// The loop exits when the owning [`TlsListener`] is dropped.
2622async fn run_tls_acceptor(
2623    listener: TcpListener,
2624    acceptor: tokio_rustls::TlsAcceptor,
2625    default_role: String,
2626    tx: mpsc::Sender<(AuthenticatedTlsStream, SocketAddr)>,
2627    handshake_timeout: Duration,
2628    max_concurrent_handshakes: usize,
2629) {
2630    let inflight = Arc::new(Semaphore::new(max_concurrent_handshakes));
2631    loop {
2632        // Acquire the permit BEFORE accepting: at saturation, pending
2633        // connections wait in the kernel backlog instead of being accepted
2634        // and then buffered or dropped in user space.
2635        let Ok(permit) = Arc::clone(&inflight).acquire_owned().await else {
2636            // The semaphore is never closed; defensive exit.
2637            return;
2638        };
2639        let (stream, addr) = match listener.accept().await {
2640            Ok(pair) => pair,
2641            Err(e) => {
2642                tracing::debug!("TCP accept error: {e}");
2643                continue;
2644            }
2645        };
2646        if tx.is_closed() {
2647            // The listener was dropped (shutdown): stop accepting.
2648            return;
2649        }
2650        let acceptor = acceptor.clone();
2651        let default_role = default_role.clone();
2652        let tx = tx.clone();
2653        tokio::spawn(async move {
2654            let _permit = permit;
2655            match tokio::time::timeout(handshake_timeout, acceptor.accept(stream)).await {
2656                Ok(Ok(tls_stream)) => {
2657                    let identity =
2658                        TlsListener::extract_handshake_identity(&tls_stream, &default_role, addr);
2659                    let wrapped = AuthenticatedTlsStream {
2660                        inner: tls_stream,
2661                        identity,
2662                    };
2663                    // The receiver only disappears during shutdown; discard
2664                    // the completed connection quietly rather than logging.
2665                    let _ = tx.send((wrapped, addr)).await;
2666                }
2667                Ok(Err(e)) => {
2668                    tracing::debug!("TLS handshake failed from {addr}: {e}");
2669                }
2670                Err(_elapsed) => {
2671                    tracing::debug!(
2672                        "TLS handshake timed out from {addr} after {handshake_timeout:?}"
2673                    );
2674                }
2675            }
2676        });
2677    }
2678}
2679
/// A TLS stream paired with the mTLS identity extracted at handshake time.
///
/// Wraps [`tokio_rustls::server::TlsStream`] so the verified client
/// identity travels with the connection itself. This replaces the previous
/// shared `MtlsIdentities` map, eliminating the
/// `(SocketAddr) -> AuthIdentity` aliasing risk caused by ephemeral-port
/// reuse and removing the need for an LRU eviction policy.
///
/// The wrapper is `Unpin` (its inner stream is `Unpin` because
/// [`tokio::net::TcpStream`] is `Unpin`), so `AsyncRead`/`AsyncWrite`
/// delegation uses safe pin projection via `Pin::new(&mut self.inner)`.
pub(crate) struct AuthenticatedTlsStream {
    /// Server-side TLS session over the accepted TCP connection; all I/O
    /// is delegated to it.
    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
    /// Identity derived from the first certificate the peer presented
    /// during the handshake. `None` when no certificate was presented or it
    /// could not be mapped to an [`AuthIdentity`].
    identity: Option<AuthIdentity>,
}
2695
2696impl AuthenticatedTlsStream {
2697    /// Returns the verified mTLS client identity, if any.
2698    #[must_use]
2699    pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
2700        self.identity.as_ref()
2701    }
2702}
2703
2704impl std::fmt::Debug for AuthenticatedTlsStream {
2705    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2706        f.debug_struct("AuthenticatedTlsStream")
2707            .field("identity", &self.identity.as_ref().map(|id| &id.name))
2708            .finish_non_exhaustive()
2709    }
2710}
2711
2712impl tokio::io::AsyncRead for AuthenticatedTlsStream {
2713    fn poll_read(
2714        mut self: Pin<&mut Self>,
2715        cx: &mut std::task::Context<'_>,
2716        buf: &mut tokio::io::ReadBuf<'_>,
2717    ) -> std::task::Poll<std::io::Result<()>> {
2718        Pin::new(&mut self.inner).poll_read(cx, buf)
2719    }
2720}
2721
2722impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
2723    fn poll_write(
2724        mut self: Pin<&mut Self>,
2725        cx: &mut std::task::Context<'_>,
2726        buf: &[u8],
2727    ) -> std::task::Poll<std::io::Result<usize>> {
2728        Pin::new(&mut self.inner).poll_write(cx, buf)
2729    }
2730
2731    fn poll_flush(
2732        mut self: Pin<&mut Self>,
2733        cx: &mut std::task::Context<'_>,
2734    ) -> std::task::Poll<std::io::Result<()>> {
2735        Pin::new(&mut self.inner).poll_flush(cx)
2736    }
2737
2738    fn poll_shutdown(
2739        mut self: Pin<&mut Self>,
2740        cx: &mut std::task::Context<'_>,
2741    ) -> std::task::Poll<std::io::Result<()>> {
2742        Pin::new(&mut self.inner).poll_shutdown(cx)
2743    }
2744
2745    fn poll_write_vectored(
2746        mut self: Pin<&mut Self>,
2747        cx: &mut std::task::Context<'_>,
2748        bufs: &[std::io::IoSlice<'_>],
2749    ) -> std::task::Poll<std::io::Result<usize>> {
2750        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
2751    }
2752
2753    fn is_write_vectored(&self) -> bool {
2754        self.inner.is_write_vectored()
2755    }
2756}
2757
2758impl axum::serve::Listener for TlsListener {
2759    type Io = AuthenticatedTlsStream;
2760    type Addr = SocketAddr;
2761
2762    /// Yield the next fully-handshaken TLS connection.
2763    ///
2764    /// Cancel-safe: this is a plain `mpsc::Receiver::recv`, so cancelling
2765    /// the future (axum selects it against graceful shutdown) never loses
2766    /// a connection.
2767    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
2768        if let Some(pair) = self.rx.recv().await {
2769            return pair;
2770        }
2771        // The channel only closes if the acceptor task terminated, which
2772        // means the TcpListener is gone and the OS already refuses new
2773        // connections. `Listener::accept` is infallible and panicking is
2774        // forbidden, so park forever: existing connections keep being
2775        // served and graceful shutdown still completes.
2776        tracing::error!("TLS acceptor task terminated; no further connections will be accepted");
2777        std::future::pending().await
2778    }
2779
2780    fn local_addr(&self) -> std::io::Result<Self::Addr> {
2781        Ok(self.local_addr)
2782    }
2783}
2784
2785impl Drop for TlsListener {
2786    fn drop(&mut self) {
2787        // Stop accepting immediately and release the bound port. In-flight
2788        // handshake workers notice the closed channel and exit quietly.
2789        self.acceptor_task.abort();
2790    }
2791}
2792
2793fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
2794    use rustls::pki_types::pem::PemObject;
2795    let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
2796        .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
2797        .collect::<Result<_, _>>()
2798        .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
2799    anyhow::ensure!(
2800        !certs.is_empty(),
2801        "no certificates found in {}",
2802        path.display()
2803    );
2804    Ok(certs)
2805}
2806
2807fn load_client_auth_roots(
2808    path: &Path,
2809) -> anyhow::Result<(
2810    Vec<rustls::pki_types::CertificateDer<'static>>,
2811    Arc<RootCertStore>,
2812)> {
2813    let ca_certs = load_certs(path)?;
2814    let mut root_store = RootCertStore::empty();
2815    for cert in &ca_certs {
2816        root_store
2817            .add(cert.clone())
2818            .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
2819    }
2820
2821    Ok((ca_certs, Arc::new(root_store)))
2822}
2823
2824fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
2825    use rustls::pki_types::pem::PemObject;
2826    rustls::pki_types::PrivateKeyDer::from_pem_file(path)
2827        .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
2828}
2829
2830#[allow(
2831    clippy::unused_async,
2832    reason = "axum route handler signature requires `async fn` even when the body is synchronous"
2833)]
2834async fn healthz() -> impl IntoResponse {
2835    axum::Json(serde_json::json!({
2836        "status": "ok",
2837    }))
2838}
2839
2840/// Build the `/version` JSON payload for a given server name and version.
2841///
2842/// Build metadata (`build_git_sha`, `build_timestamp`, `rust_version`) is
2843/// read at compile time from the `RMCP_SERVER_KIT_BUILD_SHA`,
2844/// `RMCP_SERVER_KIT_BUILD_TIME`, and `RMCP_SERVER_KIT_RUSTC_VERSION` env
2845/// vars. Unset values resolve to `"unknown"`.
2846fn version_payload(name: &str, version: &str) -> serde_json::Value {
2847    serde_json::json!({
2848        "name": name,
2849        "version": version,
2850        "build_git_sha": option_env!("RMCP_SERVER_KIT_BUILD_SHA").unwrap_or("unknown"),
2851        "build_timestamp": option_env!("RMCP_SERVER_KIT_BUILD_TIME").unwrap_or("unknown"),
2852        "rust_version": option_env!("RMCP_SERVER_KIT_RUSTC_VERSION").unwrap_or("unknown"),
2853        "mcpx_version": env!("CARGO_PKG_VERSION"),
2854    })
2855}
2856
2857/// Pre-serialize the `/version` payload to immutable bytes.
2858///
2859/// This is called once at router-build time so per-request handling can
2860/// reuse a cheap `Arc<[u8]>` clone instead of re-serializing a
2861/// [`serde_json::Value`] on every hit.
2862///
2863/// Serialization of a flat `serde_json::Value` of static-string fields
2864/// cannot fail in practice; the fallback to `b"{}"` exists only to
2865/// satisfy the crate-wide `unwrap_used` / `expect_used` lint policy.
2866fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
2867    let value = version_payload(name, version);
2868    serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
2869}
2870
2871async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2872    let status = check().await;
2873    let ready = status
2874        .get("ready")
2875        .and_then(serde_json::Value::as_bool)
2876        .unwrap_or(false);
2877    let code = if ready {
2878        axum::http::StatusCode::OK
2879    } else {
2880        axum::http::StatusCode::SERVICE_UNAVAILABLE
2881    };
2882    (code, axum::Json(status))
2883}
2884
2885/// Wait for SIGINT (ctrl-c) or SIGTERM (container stop).
2886///
2887/// On non-Unix platforms, only SIGINT is handled.
2888async fn shutdown_signal() {
2889    let ctrl_c = tokio::signal::ctrl_c();
2890
2891    #[cfg(unix)]
2892    {
2893        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
2894            Ok(mut term) => {
2895                // cancel-safe: signal-listener futures are cancel-safe per
2896                // tokio docs; no partial state in either arm.
2897                tokio::select! {
2898                    _ = ctrl_c => {}
2899                    _ = term.recv() => {}
2900                }
2901            }
2902            Err(e) => {
2903                tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
2904                ctrl_c.await.ok();
2905            }
2906        }
2907    }
2908
2909    #[cfg(not(unix))]
2910    {
2911        ctrl_c.await.ok();
2912    }
2913}
2914
2915// -- Origin validation (MCP 2025-11-25 spec, section 2.0.1) --
2916
/// Record HTTP request metrics (method, path, status, duration).
///
/// Increments `http_requests_total` (labelled by method, path and status)
/// and observes `http_request_duration_seconds` (labelled by method and
/// path) once the inner service has produced a response.
///
/// Also exposes the shared [`crate::metrics::McpMetrics`] handle to
/// inner middleware via a request extension, so the rate limiters can
/// increment `rmcp_server_kit_rate_limited_total` at their deny sites
/// (see [`crate::metrics::record_rate_limit_deny`]).
///
/// NOTE(review): the `path` label is the raw request path, so every
/// distinct URL a client sends becomes its own label value. If arbitrary
/// paths from untrusted clients reach this middleware, consider labelling
/// with axum's `MatchedPath` instead, to keep label cardinality bounded.
#[cfg(feature = "metrics")]
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    mut req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    let method = req.method().to_string();
    let path = req.uri().path().to_owned();
    let start = std::time::Instant::now();

    req.extensions_mut().insert(Arc::clone(&metrics));
    let response = next.run(req).await;

    let status = response.status().as_u16().to_string();
    let duration = start.elapsed().as_secs_f64();

    metrics
        .http_requests_total
        .with_label_values(&[&method, &path, &status])
        .inc();
    metrics
        .http_request_duration_seconds
        .with_label_values(&[&method, &path])
        .observe(duration);

    response
}
2952
2953/// OWASP security header hardening applied to every response.
2954///
2955/// Sets: `X-Content-Type-Options`, `X-Frame-Options`, `Cache-Control`,
2956/// `Referrer-Policy`, `Cross-Origin-Opener-Policy`, `Cross-Origin-Resource-Policy`,
2957/// `Cross-Origin-Embedder-Policy`, `Permissions-Policy`,
2958/// `X-Permitted-Cross-Domain-Policies`, `Content-Security-Policy`,
2959/// `X-DNS-Prefetch-Control`, and (when TLS is active) `Strict-Transport-Security`.
2960///
2961/// Each header's value can be customised via [`SecurityHeadersConfig`]
2962/// on [`McpServerConfig`]. See that type for the three-state semantic
2963/// (`None` = default, `Some("")` = omit, `Some(v)` = override).
2964async fn security_headers_middleware(
2965    is_tls: bool,
2966    cfg: Arc<SecurityHeadersConfig>,
2967    req: Request<Body>,
2968    next: Next,
2969) -> axum::response::Response {
2970    use axum::http::{HeaderName, header};
2971
2972    let mut resp = next.run(req).await;
2973    let headers = resp.headers_mut();
2974
2975    // Strip server identity headers to reduce information leakage.
2976    headers.remove(header::SERVER);
2977    headers.remove(HeaderName::from_static("x-powered-by"));
2978
2979    apply_security_header(
2980        headers,
2981        header::X_CONTENT_TYPE_OPTIONS,
2982        cfg.x_content_type_options.as_deref(),
2983        "nosniff",
2984    );
2985    apply_security_header(
2986        headers,
2987        header::X_FRAME_OPTIONS,
2988        cfg.x_frame_options.as_deref(),
2989        "deny",
2990    );
2991    apply_security_header(
2992        headers,
2993        header::CACHE_CONTROL,
2994        cfg.cache_control.as_deref(),
2995        "no-store, max-age=0",
2996    );
2997    apply_security_header(
2998        headers,
2999        header::REFERRER_POLICY,
3000        cfg.referrer_policy.as_deref(),
3001        "no-referrer",
3002    );
3003    apply_security_header(
3004        headers,
3005        HeaderName::from_static("cross-origin-opener-policy"),
3006        cfg.cross_origin_opener_policy.as_deref(),
3007        "same-origin",
3008    );
3009    apply_security_header(
3010        headers,
3011        HeaderName::from_static("cross-origin-resource-policy"),
3012        cfg.cross_origin_resource_policy.as_deref(),
3013        "same-origin",
3014    );
3015    apply_security_header(
3016        headers,
3017        HeaderName::from_static("cross-origin-embedder-policy"),
3018        cfg.cross_origin_embedder_policy.as_deref(),
3019        "require-corp",
3020    );
3021    apply_security_header(
3022        headers,
3023        HeaderName::from_static("permissions-policy"),
3024        cfg.permissions_policy.as_deref(),
3025        "accelerometer=(), camera=(), geolocation=(), microphone=()",
3026    );
3027    apply_security_header(
3028        headers,
3029        HeaderName::from_static("x-permitted-cross-domain-policies"),
3030        cfg.x_permitted_cross_domain_policies.as_deref(),
3031        "none",
3032    );
3033    apply_security_header(
3034        headers,
3035        HeaderName::from_static("content-security-policy"),
3036        cfg.content_security_policy.as_deref(),
3037        "default-src 'none'; frame-ancestors 'none'",
3038    );
3039    apply_security_header(
3040        headers,
3041        HeaderName::from_static("x-dns-prefetch-control"),
3042        cfg.x_dns_prefetch_control.as_deref(),
3043        "off",
3044    );
3045
3046    if is_tls {
3047        apply_security_header(
3048            headers,
3049            header::STRICT_TRANSPORT_SECURITY,
3050            cfg.strict_transport_security.as_deref(),
3051            "max-age=63072000; includeSubDomains",
3052        );
3053    }
3054
3055    resp
3056}
3057
3058/// Set a single security header on the response, honouring the
3059/// three-state override semantic (None = default, Some("") = omit,
3060/// Some(value) = override).
3061///
3062/// Defence-in-depth: if an override value somehow reaches this point
3063/// despite [`validate_security_headers`] having approved it (e.g. a
3064/// runtime mutation on a non-`Validated` field), we log at error level
3065/// and fall back to the static default rather than panicking. The
3066/// `Validated<McpServerConfig>` type makes that path unreachable in
3067/// well-typed code paths.
3068fn apply_security_header(
3069    headers: &mut axum::http::HeaderMap,
3070    name: axum::http::HeaderName,
3071    override_value: Option<&str>,
3072    default: &'static str,
3073) {
3074    use axum::http::HeaderValue;
3075
3076    match override_value {
3077        None => {
3078            headers.insert(name, HeaderValue::from_static(default));
3079        }
3080        Some("") => {
3081            // Operator explicitly opted out of this header.
3082        }
3083        Some(v) => match HeaderValue::from_str(v) {
3084            Ok(hv) => {
3085                headers.insert(name, hv);
3086            }
3087            Err(err) => {
3088                tracing::error!(
3089                    header = %name,
3090                    error = %err,
3091                    "invalid security header override reached middleware; using default"
3092                );
3093                headers.insert(name, HeaderValue::from_static(default));
3094            }
3095        },
3096    }
3097}
3098
3099/// Validate every non-empty entry in a [`SecurityHeadersConfig`].
3100///
3101/// - `None` and `Some("")` are accepted unconditionally (use-default and
3102///   omit, respectively).
3103/// - `Some(v)` is rejected if `axum::http::HeaderValue::from_str(v)` fails.
3104/// - `strict_transport_security` additionally rejects any value
3105///   containing `preload` (case-insensitive). Operators who genuinely
3106///   want to commit to the HSTS preload list must do so via a future
3107///   explicit `with_hsts_preload(true)` builder, not by smuggling
3108///   `preload` through this knob.
3109fn validate_security_headers(cfg: &SecurityHeadersConfig) -> Result<(), McpxError> {
3110    use axum::http::HeaderValue;
3111
3112    let fields: &[(&str, Option<&str>)] = &[
3113        (
3114            "x_content_type_options",
3115            cfg.x_content_type_options.as_deref(),
3116        ),
3117        ("x_frame_options", cfg.x_frame_options.as_deref()),
3118        ("cache_control", cfg.cache_control.as_deref()),
3119        ("referrer_policy", cfg.referrer_policy.as_deref()),
3120        (
3121            "cross_origin_opener_policy",
3122            cfg.cross_origin_opener_policy.as_deref(),
3123        ),
3124        (
3125            "cross_origin_resource_policy",
3126            cfg.cross_origin_resource_policy.as_deref(),
3127        ),
3128        (
3129            "cross_origin_embedder_policy",
3130            cfg.cross_origin_embedder_policy.as_deref(),
3131        ),
3132        ("permissions_policy", cfg.permissions_policy.as_deref()),
3133        (
3134            "x_permitted_cross_domain_policies",
3135            cfg.x_permitted_cross_domain_policies.as_deref(),
3136        ),
3137        (
3138            "content_security_policy",
3139            cfg.content_security_policy.as_deref(),
3140        ),
3141        (
3142            "x_dns_prefetch_control",
3143            cfg.x_dns_prefetch_control.as_deref(),
3144        ),
3145        (
3146            "strict_transport_security",
3147            cfg.strict_transport_security.as_deref(),
3148        ),
3149    ];
3150
3151    for (field, value) in fields {
3152        let Some(v) = value else { continue };
3153        if v.is_empty() {
3154            continue;
3155        }
3156        if let Err(err) = HeaderValue::from_str(v) {
3157            return Err(McpxError::Config(format!(
3158                "invalid security_headers.{field}: {err}"
3159            )));
3160        }
3161    }
3162
3163    if let Some(v) = cfg.strict_transport_security.as_deref()
3164        && !v.is_empty()
3165        && v.to_ascii_lowercase().contains("preload")
3166    {
3167        return Err(McpxError::Config(format!(
3168            "invalid security_headers.strict_transport_security: {v:?} contains the `preload` directive; \
3169             HSTS preload must be opted into explicitly via a dedicated builder, not via this knob"
3170        )));
3171    }
3172
3173    Ok(())
3174}
3175
/// Add the cache-related headers that OAuth token-issuing responses must
/// carry under RFC 6749 §5.1 and RFC 6750 §5.4.
///
/// [`security_headers_middleware`] already sets
/// `Cache-Control: no-store, max-age=0` on every response. On top of that,
/// this layer adds:
///
/// - `Pragma: no-cache`, which RFC 6749 §5.1 requires for HTTP/1.0 caches;
/// - `Vary: Authorization`, which RFC 6750 §5.4 requires when a response
///   depends on the `Authorization` header.
///
/// It is mounted only on the OAuth proxy's token-class endpoints
/// (`/token`, `/register`, `/introspect`, `/revoke`). `Vary` is appended
/// rather than inserted, so values contributed by other layers (such as
/// `Accept-Encoding` from compression or `Origin` from CORS) survive.
#[cfg(feature = "oauth")]
async fn oauth_token_cache_headers_middleware(
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    use axum::http::{
        HeaderValue,
        header::{PRAGMA, VARY},
    };

    let mut response = next.run(req).await;
    let out = response.headers_mut();
    out.insert(PRAGMA, HeaderValue::from_static("no-cache"));
    // Append, never insert: keep any Vary entries set by other layers.
    out.append(VARY, HeaderValue::from_static("Authorization"));
    response
}
3203
3204/// Normalize peer-address request extensions across listener branches.
3205///
3206/// The make-service installs `ConnectInfo<SocketAddr>` on the plain
3207/// listener but `ConnectInfo<TlsConnInfo>` on the TLS listener (the
3208/// latter additionally carries the connection-bound mTLS identity and
3209/// stays `pub(crate)` — see the anti-aliasing rationale on
3210/// [`TlsConnInfo`]). Application routes — in particular those merged via
3211/// [`McpServerConfig::with_extra_router`], which bypass the auth
3212/// middleware and its private fallback — could therefore not read the
3213/// peer address under TLS.
3214///
3215/// This middleware makes both branches look identical to every route and
3216/// inner middleware:
3217///
3218/// 1. mirrors the TLS peer address into `ConnectInfo<SocketAddr>` when
3219///    (and only when) it is absent, so stock axum-ecosystem extractors
3220///    work unmodified,
3221/// 2. inserts the framework-owned [`PeerAddr`] extension on both
3222///    branches, and
3223/// 3. inserts the resolved [`ClientIp`] extension: the direct peer's IP,
3224///    unless trusted-forwarder mode is configured AND the direct peer is
3225///    a trusted proxy AND the forwarding chain resolves — every
3226///    ambiguous chain falls back to the direct peer with only a reason
3227///    code logged at `debug` (never raw header contents).
3228///
3229/// Precedence mirrors the auth middleware: an existing
3230/// `ConnectInfo<SocketAddr>` always wins and is never overwritten. The
3231/// peer address is deliberately not logged here.
3232async fn normalize_peer_addr_middleware(
3233    resolver: Option<Arc<ForwardResolver>>,
3234    mut req: Request<Body>,
3235    next: Next,
3236) -> axum::response::Response {
3237    let direct = req
3238        .extensions()
3239        .get::<ConnectInfo<SocketAddr>>()
3240        .map(|ci| ci.0);
3241    let from_tls = req
3242        .extensions()
3243        .get::<ConnectInfo<TlsConnInfo>>()
3244        .map(|ci| ci.0.addr);
3245    if let Some(addr) = direct.or(from_tls) {
3246        if direct.is_none() {
3247            req.extensions_mut().insert(ConnectInfo(addr));
3248        }
3249        req.extensions_mut().insert(PeerAddr::new(addr));
3250        let client_ip = match &resolver {
3251            Some(r) => {
3252                crate::forwarded::resolve_client_ip(addr.ip(), req.headers(), &r.trusted, r.mode)
3253                    .unwrap_or_else(|reason| {
3254                        tracing::debug!(
3255                            reason = ?reason,
3256                            "forwarded-header resolution fell back to direct peer"
3257                        );
3258                        addr.ip()
3259                    })
3260            }
3261            None => addr.ip(),
3262        };
3263        req.extensions_mut().insert(ClientIp::new(client_ip));
3264    }
3265    next.run(req).await
3266}
3267
3268/// Parse a trusted-proxy entry: a CIDR (`10.0.0.0/8`) or a bare IP
3269/// (normalized to a `/32` / `/128` host network).
3270fn parse_proxy_net(entry: &str) -> Option<ipnet::IpNet> {
3271    if let Ok(net) = entry.parse::<ipnet::IpNet>() {
3272        return Some(net);
3273    }
3274    entry.parse::<IpAddr>().ok().map(ipnet::IpNet::from)
3275}
3276
3277/// Rate-limit key for the current request: the resolved [`ClientIp`]
3278/// when present, else the direct peer from either `ConnectInfo` form.
3279/// All four built-in limiters key through this helper.
3280pub(crate) fn limiter_client_ip(extensions: &axum::http::Extensions) -> Option<IpAddr> {
3281    if let Some(client) = extensions.get::<ClientIp>() {
3282        return Some(client.ip);
3283    }
3284    extensions
3285        .get::<ConnectInfo<SocketAddr>>()
3286        .map(|ci| ci.0.ip())
3287        .or_else(|| {
3288            extensions
3289                .get::<ConnectInfo<TlsConnInfo>>()
3290                .map(|ci| ci.0.addr.ip())
3291        })
3292}
3293
3294/// Per-IP rate limiter for `extra_router` routes, keyed by the direct
3295/// socket peer address. Same memory-bounded machinery as the tool
3296/// limiter ([`crate::rbac`]).
3297pub(crate) type ExtraRouteRateLimiter = BoundedKeyedLimiter<IpAddr>;
3298
/// Upper bound on how many distinct source IPs the extra-route limiter
/// keeps state for.
///
/// This matches the tool limiter's bound. Memory stays bounded when the
/// limiter is saturated (idle pruning plus LRU eviction). The price is
/// shared-fate fairness under key spray: an attacker cycling through many
/// IPs can push quieter legitimate IPs back to fresh buckets.
const EXTRA_ROUTE_MAX_TRACKED_KEYS: usize = 10_000;

/// How long an extra-route limiter entry may sit idle before it is evicted
/// (15 minutes, matching the tool limiter).
const EXTRA_ROUTE_IDLE_EVICTION: Duration = Duration::from_secs(15 * 60);
3309
3310/// Build the per-IP limiter for `extra_router` routes.
3311///
3312/// `per_minute` and `burst` are validated nonzero by
3313/// [`McpServerConfig::validate`]; the `NonZeroU32` fallbacks here are
3314/// defensive only. `burst` overrides governor's default bucket capacity
3315/// (burst = rate).
3316fn build_extra_route_rate_limiter(
3317    per_minute: u32,
3318    burst: Option<u32>,
3319) -> Arc<ExtraRouteRateLimiter> {
3320    let rate = std::num::NonZeroU32::new(per_minute.max(1)).unwrap_or(std::num::NonZeroU32::MIN);
3321    let mut quota = governor::Quota::per_minute(rate);
3322    if let Some(b) = burst.and_then(std::num::NonZeroU32::new) {
3323        quota = quota.allow_burst(b);
3324    }
3325    Arc::new(BoundedKeyedLimiter::new(
3326        quota,
3327        EXTRA_ROUTE_MAX_TRACKED_KEYS,
3328        EXTRA_ROUTE_IDLE_EVICTION,
3329    ))
3330}
3331
3332/// Per-IP rate limit middleware for `extra_router` routes.
3333///
3334/// Applied to the application-supplied router **before** it is merged
3335/// into the top-level router, so it wraps exactly the extra routes
3336/// (and their fallback, if any) and nothing else — `/mcp`, health,
3337/// admin, and OAuth endpoints are never affected. Outer layers (origin
3338/// check, peer-address normalization, security headers, metrics) still
3339/// wrap these routes and run first, so both `ConnectInfo` forms are
3340/// populated by the time this middleware reads them.
3341///
3342/// Semantics mirror the tool/auth limiters exactly: keyed by the
3343/// direct peer `IpAddr` (no `X-Forwarded-For`), fail-open when no peer
3344/// address is present (cannot happen under [`serve`]), and on limit a
3345/// plain-text 429 via [`McpxError::RateLimitedFor`] carrying a
3346/// `Retry-After` header (delta-seconds), consistent with every other
3347/// limiter in the crate.
3348///
3349/// `exempt` holds raw exact-match paths (validated at config time)
3350/// checked against `req.uri().path()` **before** key extraction:
3351/// exempt requests consume no limiter budget and produce no deny
3352/// telemetry. Fail-closed — any non-listed path stays limited.
3353async fn extra_route_rate_limit_middleware(
3354    limiter: Arc<ExtraRouteRateLimiter>,
3355    exempt: Arc<std::collections::HashSet<String>>,
3356    req: Request<Body>,
3357    next: Next,
3358) -> axum::response::Response {
3359    if exempt.contains(req.uri().path()) {
3360        return next.run(req).await;
3361    }
3362    let peer_ip: Option<IpAddr> = limiter_client_ip(req.extensions());
3363    if let Some(ip) = peer_ip
3364        && let Err(wait) = limiter.check_key_wait(&ip)
3365    {
3366        #[cfg(feature = "metrics")]
3367        crate::metrics::record_rate_limit_deny(req.extensions(), "extra_route");
3368        tracing::warn!(%ip, "extra route request rate limited");
3369        return McpxError::RateLimitedFor {
3370            message: "too many requests to application routes from this source".into(),
3371            retry_after: wait,
3372        }
3373        .into_response();
3374    }
3375    next.run(req).await
3376}
3377
3378/// Per the MCP spec: if the Origin header is present and its value is not in
3379/// the allowed list, respond with 403 Forbidden. Requests without an Origin
3380/// header are allowed through (e.g. non-browser clients like curl, SDKs).
3381async fn origin_check_middleware(
3382    allowed: Arc<[String]>,
3383    log_request_headers: bool,
3384    req: Request<Body>,
3385    next: Next,
3386) -> axum::response::Response {
3387    let method = req.method().clone();
3388    let path = req.uri().path().to_owned();
3389
3390    log_incoming_request(&method, &path, req.headers(), log_request_headers);
3391
3392    if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
3393        let origin_str = origin.to_str().unwrap_or("");
3394        if !allowed.iter().any(|a| a == origin_str) {
3395            tracing::warn!(
3396                origin = origin_str,
3397                %method,
3398                %path,
3399                allowed = ?&*allowed,
3400                "rejected request: Origin not allowed"
3401            );
3402            return (
3403                axum::http::StatusCode::FORBIDDEN,
3404                "Forbidden: Origin not allowed",
3405            )
3406                .into_response();
3407        }
3408    }
3409    next.run(req).await
3410}
3411
3412/// Emit a DEBUG log for an incoming request, optionally including the full
3413/// (redacted) header set.
3414fn log_incoming_request(
3415    method: &axum::http::Method,
3416    path: &str,
3417    headers: &axum::http::HeaderMap,
3418    log_request_headers: bool,
3419) {
3420    if log_request_headers {
3421        tracing::debug!(
3422            %method,
3423            %path,
3424            headers = %format_request_headers_for_log(headers),
3425            "incoming request"
3426        );
3427    } else {
3428        tracing::debug!(%method, %path, "incoming request");
3429    }
3430}
3431
3432fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
3433    headers
3434        .iter()
3435        .map(|(k, v)| {
3436            let name = k.as_str();
3437            if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
3438                format!("{name}: [REDACTED]")
3439            } else {
3440                format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
3441            }
3442        })
3443        .collect::<Vec<_>>()
3444        .join(", ")
3445}
3446
3447// -- stdio transport --
3448
3449/// Serve an MCP server over stdin/stdout (stdio transport).
3450///
3451/// # Security warnings
3452///
3453/// - **No authentication**: the parent process has full, unrestricted access.
3454/// - **No RBAC**: all tools are available regardless of policy.
3455/// - **No TLS**: messages travel over OS pipes in plaintext.
3456/// - **Single client**: only the parent process can connect.
3457/// - **No Origin validation**: not applicable to stdio.
3458///
3459/// Use this only when the MCP client spawns the server as a trusted subprocess
3460/// (e.g. Claude Desktop, VS Code Copilot). For network-accessible deployments,
3461/// use `serve()` (Streamable HTTP) instead.
3462///
3463/// # Errors
3464///
3465/// Returns [`McpxError::Startup`] if the handler fails to initialize or the
3466/// transport disconnects unexpectedly.
3467// NOTE: reported complexity 32/25 is driven entirely by `tracing::*!`
3468// macro expansion in this 18-line function (info/warn/info + two matches).
3469// There is nothing meaningful to extract; the allow stays.
3470#[allow(
3471    clippy::cognitive_complexity,
3472    reason = "complexity is purely tracing macro expansion (info/warn + match arms); 18 lines of straight-line code, nothing meaningful to extract"
3473)]
3474pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
3475where
3476    H: ServerHandler + 'static,
3477{
3478    use rmcp::ServiceExt as _;
3479
3480    tracing::info!("stdio transport: serving on stdin/stdout");
3481    tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
3482
3483    let transport = rmcp::transport::io::stdio();
3484
3485    let service = handler
3486        .serve(transport)
3487        .await
3488        .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
3489
3490    if let Err(e) = service.waiting().await {
3491        tracing::warn!(error = %e, "stdio session ended with error");
3492    }
3493    tracing::info!("stdio session ended");
3494    Ok(())
3495}
3496
3497#[cfg(test)]
3498mod tests {
3499    #![allow(
3500        clippy::unwrap_used,
3501        clippy::expect_used,
3502        clippy::panic,
3503        clippy::indexing_slicing,
3504        clippy::unwrap_in_result,
3505        clippy::print_stdout,
3506        clippy::print_stderr,
3507        deprecated,
3508        reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
3509    )]
3510    use std::{sync::Arc, time::Duration};
3511
3512    use axum::{
3513        body::Body,
3514        http::{Request, StatusCode, header},
3515        response::IntoResponse,
3516    };
3517    use http_body_util::BodyExt;
3518    use tower::ServiceExt as _;
3519
3520    use super::*;
3521
3522    // -- McpServerConfig --
3523
3524    #[test]
3525    fn server_config_new_defaults() {
3526        let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
3527        assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
3528        assert_eq!(cfg.name, "test-server");
3529        assert_eq!(cfg.version, "1.0.0");
3530        assert!(cfg.tls_cert_path.is_none());
3531        assert!(cfg.tls_key_path.is_none());
3532        assert!(cfg.auth.is_none());
3533        assert!(cfg.rbac.is_none());
3534        assert!(cfg.allowed_origins.is_empty());
3535        assert!(cfg.tool_rate_limit.is_none());
3536        assert!(cfg.readiness_check.is_none());
3537        assert_eq!(cfg.max_request_body, 1024 * 1024);
3538        assert_eq!(cfg.request_timeout, Duration::from_mins(2));
3539        assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
3540        assert!(!cfg.log_request_headers);
3541        assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(10));
3542        assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
3543    }
3544
3545    #[test]
3546    fn tls_handshake_builders_set_fields() {
3547        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3548            .with_tls_handshake_timeout(Duration::from_secs(3))
3549            .with_max_concurrent_tls_handshakes(64);
3550        assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(3));
3551        assert_eq!(cfg.max_concurrent_tls_handshakes, 64);
3552    }
3553
3554    #[test]
3555    fn validate_rejects_zero_tls_handshake_timeout() {
3556        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3557            .with_tls_handshake_timeout(Duration::ZERO);
3558        let err = cfg.validate().expect_err("zero handshake timeout");
3559        assert!(err.to_string().contains("tls_handshake_timeout"));
3560    }
3561
3562    #[test]
3563    fn validate_rejects_zero_max_concurrent_tls_handshakes() {
3564        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3565            .with_max_concurrent_tls_handshakes(0);
3566        let err = cfg.validate().expect_err("zero handshake concurrency");
3567        assert!(err.to_string().contains("max_concurrent_tls_handshakes"));
3568    }
3569
3570    #[test]
3571    fn validate_consumes_and_proves() {
3572        // Valid config -> Validated wrapper, original is consumed.
3573        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
3574        let validated = cfg.validate().expect("valid config");
3575        // as_inner() gives read-only access to inner fields.
3576        assert_eq!(validated.as_inner().name, "test-server");
3577        // into_inner recovers the raw value.
3578        let raw = validated.into_inner();
3579        assert_eq!(raw.name, "test-server");
3580
3581        // Invalid config (zero max_request_body) -> Err.
3582        let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
3583        bad.max_request_body = 0;
3584        assert!(bad.validate().is_err(), "zero body cap must fail validate");
3585    }
3586
3587    #[test]
3588    fn validate_rejects_zero_max_concurrent_requests() {
3589        let cfg =
3590            McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_max_concurrent_requests(0);
3591        let err = cfg.validate().expect_err("zero concurrency cap must fail");
3592        assert!(
3593            format!("{err}").contains("max_concurrent_requests"),
3594            "error should mention max_concurrent_requests, got: {err}"
3595        );
3596    }
3597
3598    #[test]
3599    fn validate_rejects_zero_max_tracked_keys() {
3600        // Defaults mirror auth::default_max_attempts / default_idle_eviction
3601        // (module-private in auth.rs); spelled out here for review clarity.
3602        let rl = crate::auth::RateLimitConfig {
3603            max_attempts_per_minute: 30,
3604            pre_auth_max_per_minute: None,
3605            max_tracked_keys: 0,
3606            idle_eviction: Duration::from_secs(15 * 60),
3607            burst: None,
3608            pre_auth_burst: None,
3609        };
3610        let auth_cfg = AuthConfig {
3611            enabled: true,
3612            api_keys: Vec::new(),
3613            mtls: None,
3614            rate_limit: Some(rl),
3615            #[cfg(feature = "oauth")]
3616            oauth: None,
3617        };
3618        let cfg = McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_auth(auth_cfg);
3619        let err = cfg.validate().expect_err("zero max_tracked_keys must fail");
3620        assert!(
3621            format!("{err}").contains("max_tracked_keys"),
3622            "error should mention max_tracked_keys, got: {err}"
3623        );
3624    }
3625
3626    #[test]
3627    fn derive_allowed_hosts_includes_public_host() {
3628        let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
3629        assert!(
3630            hosts.iter().any(|h| h == "mcp.example.com"),
3631            "public_url host must be allowed"
3632        );
3633    }
3634
3635    #[test]
3636    fn derive_allowed_hosts_includes_bind_authority() {
3637        let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
3638        assert!(
3639            hosts.iter().any(|h| h == "127.0.0.1"),
3640            "bind host must be allowed"
3641        );
3642        assert!(
3643            hosts.iter().any(|h| h == "127.0.0.1:8080"),
3644            "bind authority must be allowed"
3645        );
3646    }
3647
3648    // -- healthz --
3649
3650    #[tokio::test]
3651    async fn healthz_returns_ok_json() {
3652        let resp = healthz().await.into_response();
3653        assert_eq!(resp.status(), StatusCode::OK);
3654        let body = resp.into_body().collect().await.unwrap().to_bytes();
3655        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
3656        assert_eq!(json["status"], "ok");
3657        assert!(
3658            json.get("name").is_none(),
3659            "healthz must not expose server name"
3660        );
3661        assert!(
3662            json.get("version").is_none(),
3663            "healthz must not expose version"
3664        );
3665    }
3666
3667    // -- readyz --
3668
3669    #[tokio::test]
3670    async fn readyz_returns_ok_when_ready() {
3671        let check: ReadinessCheck =
3672            Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
3673        let resp = readyz(check).await.into_response();
3674        assert_eq!(resp.status(), StatusCode::OK);
3675        let body = resp.into_body().collect().await.unwrap().to_bytes();
3676        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
3677        assert_eq!(json["ready"], true);
3678        assert!(
3679            json.get("name").is_none(),
3680            "readyz must not expose server name"
3681        );
3682        assert!(
3683            json.get("version").is_none(),
3684            "readyz must not expose version"
3685        );
3686        assert_eq!(json["db"], "connected");
3687    }
3688
3689    #[tokio::test]
3690    async fn readyz_returns_503_when_not_ready() {
3691        let check: ReadinessCheck =
3692            Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
3693        let resp = readyz(check).await.into_response();
3694        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
3695    }
3696
3697    #[tokio::test]
3698    async fn readyz_returns_503_when_ready_missing() {
3699        let check: ReadinessCheck =
3700            Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
3701        let resp = readyz(check).await.into_response();
3702        // Missing "ready" field defaults to false -> 503
3703        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
3704    }
3705
3706    // -- normalize_peer_addr_middleware / PeerAddr --
3707
    /// Build a test router that reports the request's peer-address
    /// extensions as `"<ConnectInfo>|<PeerAddr>"` (empty when absent).
    ///
    /// `normalize_peer_addr_middleware` is installed with `None` as its
    /// forward resolver, so only the direct connection peer is considered.
    fn peer_probe_router() -> axum::Router {
        // Echo whatever peer extensions the middleware left on the request;
        // a missing extension renders as an empty string.
        async fn probe(req: Request<Body>) -> String {
            let ci = req
                .extensions()
                .get::<ConnectInfo<SocketAddr>>()
                .map(|c| c.0.to_string())
                .unwrap_or_default();
            let pa = req
                .extensions()
                .get::<PeerAddr>()
                .map(|p| p.addr.to_string())
                .unwrap_or_default();
            format!("{ci}|{pa}")
        }
        axum::Router::new()
            .route("/probe", axum::routing::get(probe))
            // Layer after `route` so the middleware wraps `/probe`.
            .layer(axum::middleware::from_fn(|req, next| {
                normalize_peer_addr_middleware(None, req, next)
            }))
    }
3730
3731    async fn body_string(resp: axum::response::Response) -> String {
3732        let bytes = resp.into_body().collect().await.unwrap().to_bytes();
3733        String::from_utf8(bytes.to_vec()).unwrap()
3734    }
3735
3736    #[tokio::test]
3737    async fn normalize_preserves_existing_connect_info_and_mirrors_peer_addr() {
3738        // Precedence proof: when both extensions exist with DIFFERENT
3739        // addresses, ConnectInfo<SocketAddr> wins and is never overwritten.
3740        let plain: SocketAddr = "10.0.0.1:1111".parse().unwrap();
3741        let tls: SocketAddr = "10.0.0.2:2222".parse().unwrap();
3742        let req = Request::builder()
3743            .uri("/probe")
3744            .extension(ConnectInfo(plain))
3745            .extension(ConnectInfo(TlsConnInfo::new(tls, None)))
3746            .body(Body::empty())
3747            .unwrap();
3748        let resp = peer_probe_router().oneshot(req).await.unwrap();
3749        assert_eq!(resp.status(), StatusCode::OK);
3750        assert_eq!(body_string(resp).await, format!("{plain}|{plain}"));
3751    }
3752
3753    #[tokio::test]
3754    async fn normalize_inserts_connect_info_and_peer_addr_from_tls() {
3755        let tls: SocketAddr = "192.168.1.7:50443".parse().unwrap();
3756        let req = Request::builder()
3757            .uri("/probe")
3758            .extension(ConnectInfo(TlsConnInfo::new(tls, None)))
3759            .body(Body::empty())
3760            .unwrap();
3761        let resp = peer_probe_router().oneshot(req).await.unwrap();
3762        assert_eq!(resp.status(), StatusCode::OK);
3763        assert_eq!(body_string(resp).await, format!("{tls}|{tls}"));
3764    }
3765
3766    #[tokio::test]
3767    async fn normalize_no_op_without_any_connect_info() {
3768        let req = Request::builder()
3769            .uri("/probe")
3770            .body(Body::empty())
3771            .unwrap();
3772        let resp = peer_probe_router().oneshot(req).await.unwrap();
3773        assert_eq!(resp.status(), StatusCode::OK);
3774        assert_eq!(body_string(resp).await, "|");
3775    }
3776
3777    #[tokio::test]
3778    async fn peer_addr_extractor_rejects_when_absent() {
3779        async fn h(peer: PeerAddr) -> String {
3780            peer.addr.to_string()
3781        }
3782        let app = axum::Router::new().route("/p", axum::routing::get(h));
3783        let req = Request::builder().uri("/p").body(Body::empty()).unwrap();
3784        let resp = app.oneshot(req).await.unwrap();
3785        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
3786    }
3787
3788    #[tokio::test]
3789    async fn peer_addr_extractor_returns_value_when_present() {
3790        async fn h(peer: PeerAddr) -> String {
3791            peer.addr.to_string()
3792        }
3793        let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap();
3794        let app = axum::Router::new().route("/p", axum::routing::get(h));
3795        let req = Request::builder()
3796            .uri("/p")
3797            .extension(PeerAddr::new(addr))
3798            .body(Body::empty())
3799            .unwrap();
3800        let resp = app.oneshot(req).await.unwrap();
3801        assert_eq!(resp.status(), StatusCode::OK);
3802        assert_eq!(body_string(resp).await, addr.to_string());
3803    }
3804
3805    #[tokio::test]
3806    async fn peer_addr_via_extension_extractor() {
3807        async fn h(axum::Extension(peer): axum::Extension<PeerAddr>) -> String {
3808            peer.addr.to_string()
3809        }
3810        let addr: SocketAddr = "127.0.0.1:4242".parse().unwrap();
3811        let app = axum::Router::new().route("/p", axum::routing::get(h));
3812        let req = Request::builder()
3813            .uri("/p")
3814            .extension(PeerAddr::new(addr))
3815            .body(Body::empty())
3816            .unwrap();
3817        let resp = app.oneshot(req).await.unwrap();
3818        assert_eq!(resp.status(), StatusCode::OK);
3819        assert_eq!(body_string(resp).await, addr.to_string());
3820    }
3821
3822    // -- extra_route_rate_limit_middleware --
3823
3824    /// Probe router with the extra-route limiter installed, mirroring
3825    /// the layer-before-merge wiring in `build_app_router`.
3826    fn limited_router(per_minute: u32) -> axum::Router {
3827        limited_router_with_burst(per_minute, None)
3828    }
3829
3830    /// Probe router with an explicit burst capacity.
3831    fn limited_router_with_burst(per_minute: u32, burst: Option<u32>) -> axum::Router {
3832        limited_router_full(per_minute, burst, &[])
3833    }
3834
    /// Probe router with explicit burst and exempt paths. `/limited`
    /// and `/exempt` are both registered so exemption interplay can be
    /// asserted on one limiter instance.
    ///
    /// A single limiter (built by `build_extra_route_rate_limiter`) and a
    /// single exempt-path set are shared by every request the returned
    /// router (and its clones) handles.
    fn limited_router_full(
        per_minute: u32,
        burst: Option<u32>,
        exempt_paths: &[&str],
    ) -> axum::Router {
        let limiter = build_extra_route_rate_limiter(per_minute, burst);
        let exempt: Arc<std::collections::HashSet<String>> =
            Arc::new(exempt_paths.iter().map(|s| (*s).to_owned()).collect());
        axum::Router::new()
            .route("/limited", axum::routing::get(|| async { "ok" }))
            .route("/exempt", axum::routing::get(|| async { "ok" }))
            .layer(axum::middleware::from_fn(move |req, next| {
                // Clone the shared handles per call: the closure is invoked
                // once per request, so it must not move its captures out.
                let l = Arc::clone(&limiter);
                let e = Arc::clone(&exempt);
                extra_route_rate_limit_middleware(l, e, req, next)
            }))
    }
3855
3856    fn limited_req(ip: &str) -> Request<Body> {
3857        limited_req_to(ip, "/limited")
3858    }
3859
3860    fn limited_req_to(ip: &str, path: &str) -> Request<Body> {
3861        let addr: SocketAddr = format!("{ip}:40000").parse().unwrap();
3862        Request::builder()
3863            .uri(path)
3864            .extension(ConnectInfo(addr))
3865            .body(Body::empty())
3866            .unwrap()
3867    }
3868
3869    #[tokio::test]
3870    async fn extra_route_limiter_denies_over_quota() {
3871        let app = limited_router(2);
3872        for i in 0..2 {
3873            let resp = app.clone().oneshot(limited_req("10.1.1.1")).await.unwrap();
3874            assert_eq!(resp.status(), StatusCode::OK, "request {i} should pass");
3875        }
3876        let resp = app.clone().oneshot(limited_req("10.1.1.1")).await.unwrap();
3877        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
3878        let body = body_string(resp).await;
3879        assert!(
3880            body.contains("too many requests to application routes"),
3881            "deny body should match the limiter message, got: {body}"
3882        );
3883    }
3884
3885    #[tokio::test]
3886    async fn extra_route_limiter_isolates_keys() {
3887        let app = limited_router(2);
3888        for _ in 0..2 {
3889            let resp = app.clone().oneshot(limited_req("10.2.2.2")).await.unwrap();
3890            assert_eq!(resp.status(), StatusCode::OK);
3891        }
3892        let exhausted = app.clone().oneshot(limited_req("10.2.2.2")).await.unwrap();
3893        assert_eq!(exhausted.status(), StatusCode::TOO_MANY_REQUESTS);
3894        // A different source IP still has a fresh bucket.
3895        let other = app.clone().oneshot(limited_req("10.3.3.3")).await.unwrap();
3896        assert_eq!(other.status(), StatusCode::OK);
3897    }
3898
3899    #[tokio::test]
3900    async fn extra_route_limiter_fails_open_without_peer() {
3901        let app = limited_router(1);
3902        for i in 0..3 {
3903            let req = Request::builder()
3904                .uri("/limited")
3905                .body(Body::empty())
3906                .unwrap();
3907            let resp = app.clone().oneshot(req).await.unwrap();
3908            assert_eq!(
3909                resp.status(),
3910                StatusCode::OK,
3911                "request {i} should fail open"
3912            );
3913        }
3914    }
3915
3916    #[tokio::test]
3917    async fn extra_route_limiter_extracts_tls_conn_info() {
3918        let app = limited_router(2);
3919        let mk = || {
3920            let addr: SocketAddr = "192.168.9.9:55555".parse().unwrap();
3921            Request::builder()
3922                .uri("/limited")
3923                .extension(ConnectInfo(TlsConnInfo::new(addr, None)))
3924                .body(Body::empty())
3925                .unwrap()
3926        };
3927        for _ in 0..2 {
3928            assert_eq!(
3929                app.clone().oneshot(mk()).await.unwrap().status(),
3930                StatusCode::OK
3931            );
3932        }
3933        let resp = app.clone().oneshot(mk()).await.unwrap();
3934        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
3935    }
3936
3937    #[tokio::test]
3938    async fn extra_route_limiter_exempt_path_bypasses_quota() {
3939        // rate=1: a single non-exempt request exhausts the bucket, yet
3940        // repeated exempt-path requests all pass and consume no budget.
3941        let app = limited_router_full(1, None, &["/exempt"]);
3942        for i in 0..5 {
3943            let resp = app
3944                .clone()
3945                .oneshot(limited_req_to("10.6.6.6", "/exempt"))
3946                .await
3947                .unwrap();
3948            assert_eq!(resp.status(), StatusCode::OK, "exempt request {i}");
3949        }
3950        // Budget untouched by exempt traffic: first limited request OK…
3951        let resp = app.clone().oneshot(limited_req("10.6.6.6")).await.unwrap();
3952        assert_eq!(resp.status(), StatusCode::OK);
3953        // …second is denied (exemption did not leak onto /limited).
3954        let resp = app.clone().oneshot(limited_req("10.6.6.6")).await.unwrap();
3955        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
3956    }
3957
3958    #[tokio::test]
3959    async fn extra_route_limiter_exemption_is_raw_exact_match() {
3960        // Trailing-slash and case variants are NOT exempt (fail-closed:
3961        // a mismatch keeps the request limited, never the reverse).
3962        let app = limited_router_full(1, None, &["/exempt"]);
3963        let ok = app
3964            .clone()
3965            .oneshot(limited_req_to("10.7.7.7", "/exempt/"))
3966            .await
3967            .unwrap();
3968        assert_eq!(
3969            ok.status(),
3970            StatusCode::NOT_FOUND,
3971            "variant path routes 404"
3972        );
3973        // The variant consumed limiter budget (it was not exempt):
3974        let denied = app
3975            .clone()
3976            .oneshot(limited_req_to("10.7.7.7", "/limited"))
3977            .await
3978            .unwrap();
3979        assert_eq!(denied.status(), StatusCode::TOO_MANY_REQUESTS);
3980    }
3981
3982    #[cfg(feature = "metrics")]
3983    #[tokio::test]
3984    async fn extra_route_limiter_deny_increments_counter_exempt_does_not() {
3985        let metrics = Arc::new(crate::metrics::McpMetrics::new().unwrap());
3986        let app = limited_router_full(1, None, &["/exempt"]);
3987        let mk = |path: &str| {
3988            let addr: SocketAddr = "10.8.8.8:40000".parse().unwrap();
3989            Request::builder()
3990                .uri(path)
3991                .extension(ConnectInfo(addr))
3992                .extension(Arc::clone(&metrics))
3993                .body(Body::empty())
3994                .unwrap()
3995        };
3996        let counter = || {
3997            metrics
3998                .rate_limited_total
3999                .with_label_values(&["extra_route"])
4000                .get()
4001        };
4002        // Exempt traffic: no budget, no counter.
4003        for _ in 0..3 {
4004            assert_eq!(
4005                app.clone().oneshot(mk("/exempt")).await.unwrap().status(),
4006                StatusCode::OK
4007            );
4008        }
4009        assert_eq!(counter(), 0, "exempt requests must not count as denies");
4010        // Exhaust then deny: counter increments exactly on the deny.
4011        assert_eq!(
4012            app.clone().oneshot(mk("/limited")).await.unwrap().status(),
4013            StatusCode::OK
4014        );
4015        assert_eq!(counter(), 0);
4016        assert_eq!(
4017            app.clone().oneshot(mk("/limited")).await.unwrap().status(),
4018            StatusCode::TOO_MANY_REQUESTS
4019        );
4020        assert_eq!(counter(), 1, "deny must increment the extra_route label");
4021    }
4022
4023    #[test]
4024    fn validate_rejects_exempt_paths_without_base_knob() {
4025        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
4026            .with_extra_route_rate_limit_exempt_paths(["/ok"]);
4027        let err = cfg.validate().expect_err("exempt paths without rate limit");
4028        assert!(err.to_string().contains("requires extra_route_rate_limit"));
4029    }
4030
4031    #[test]
4032    fn validate_rejects_malformed_exempt_paths() {
4033        for bad in ["", "no-slash"] {
4034            let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
4035                .with_extra_route_rate_limit(10)
4036                .with_extra_route_rate_limit_exempt_paths([bad]);
4037            let err = cfg.validate().expect_err("malformed exempt path");
4038            assert!(
4039                err.to_string()
4040                    .contains("must be non-empty and start with '/'"),
4041                "entry {bad:?}: {err}"
4042            );
4043        }
4044    }
4045
4046    #[test]
4047    fn validate_accepts_wellformed_exempt_paths() {
4048        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
4049            .with_extra_route_rate_limit(10)
4050            .with_extra_route_rate_limit_exempt_paths(["/.well-known/oauth-authorization-server"]);
4051        assert!(cfg.validate().is_ok());
4052    }
4053
4054    #[test]
4055    fn validate_rejects_zero_extra_route_rate_limit() {
4056        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
4057            .with_extra_route_rate_limit(0);
4058        let err = cfg.validate().expect_err("zero extra route rate limit");
4059        assert!(err.to_string().contains("extra_route_rate_limit"));
4060    }
4061
4062    #[tokio::test]
4063    async fn extra_route_limiter_burst_allows_initial_spike() {
4064        let app = limited_router_with_burst(1, Some(3));
4065        for i in 0..3 {
4066            let resp = app.clone().oneshot(limited_req("10.4.4.4")).await.unwrap();
4067            assert_eq!(resp.status(), StatusCode::OK, "burst request {i}");
4068        }
4069        let resp = app.clone().oneshot(limited_req("10.4.4.4")).await.unwrap();
4070        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
4071    }
4072
4073    #[tokio::test]
4074    async fn extra_route_limiter_deny_sets_retry_after() {
4075        let app = limited_router(1);
4076        let ok = app.clone().oneshot(limited_req("10.5.5.5")).await.unwrap();
4077        assert_eq!(ok.status(), StatusCode::OK);
4078        let denied = app.clone().oneshot(limited_req("10.5.5.5")).await.unwrap();
4079        assert_eq!(denied.status(), StatusCode::TOO_MANY_REQUESTS);
4080        let retry_after = denied
4081            .headers()
4082            .get(header::RETRY_AFTER)
4083            .expect("Retry-After present")
4084            .to_str()
4085            .unwrap()
4086            .parse::<u64>()
4087            .unwrap();
4088        assert!(retry_after >= 1, "delta-seconds must be >= 1");
4089    }
4090
4091    #[test]
4092    fn validate_rejects_zero_burst_knobs() {
4093        let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4094            .with_tool_rate_limit(10)
4095            .with_tool_rate_limit_burst(0)
4096            .validate()
4097            .expect_err("zero tool burst");
4098        assert!(err.to_string().contains("tool_rate_limit_burst"));
4099
4100        let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4101            .with_extra_route_rate_limit(10)
4102            .with_extra_route_rate_limit_burst(0)
4103            .validate()
4104            .expect_err("zero extra route burst");
4105        assert!(err.to_string().contains("extra_route_rate_limit_burst"));
4106    }
4107
4108    #[test]
4109    fn validate_rejects_orphan_burst_knobs() {
4110        let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4111            .with_tool_rate_limit_burst(5)
4112            .validate()
4113            .expect_err("orphan tool burst");
4114        assert!(err.to_string().contains("requires tool_rate_limit"));
4115
4116        let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4117            .with_extra_route_rate_limit_burst(5)
4118            .validate()
4119            .expect_err("orphan extra route burst");
4120        assert!(err.to_string().contains("requires extra_route_rate_limit"));
4121    }
4122
4123    #[test]
4124    fn validate_rejects_zero_auth_bursts() {
4125        let auth = AuthConfig::with_keys(vec![])
4126            .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_burst(0));
4127        let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4128            .with_auth(auth)
4129            .validate()
4130            .expect_err("zero auth burst");
4131        assert!(err.to_string().contains("rate_limit.burst"));
4132
4133        let auth = AuthConfig::with_keys(vec![])
4134            .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_pre_auth_burst(0));
4135        let err = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4136            .with_auth(auth)
4137            .validate()
4138            .expect_err("zero pre-auth burst");
4139        assert!(err.to_string().contains("pre_auth_burst"));
4140    }
4141
4142    /// `pre_auth_burst` without `pre_auth_max_per_minute` is LEGAL: the
4143    /// pre-auth base rate always resolves (max_attempts_per_minute x 10).
4144    #[test]
4145    fn validate_accepts_pre_auth_burst_without_explicit_pre_auth_rate() {
4146        let auth = AuthConfig::with_keys(vec![])
4147            .with_rate_limit(crate::auth::RateLimitConfig::new(10).with_pre_auth_burst(50));
4148        let cfg = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0").with_auth(auth);
4149        assert!(cfg.validate().is_ok(), "pre_auth_burst has no orphan rule");
4150    }
4151
4152    // -- trusted-forwarder mode (ClientIp / ForwardedHeaderMode) --
4153
4154    fn forward_resolver(trusted: &[&str], mode: ForwardedHeaderMode) -> Arc<ForwardResolver> {
4155        Arc::new(ForwardResolver {
4156            trusted: trusted.iter().map(|s| s.parse().unwrap()).collect(),
4157            mode,
4158        })
4159    }
4160
    /// Probe router reporting `"<PeerAddr ip>|<ClientIp>"`.
    ///
    /// `resolver` is handed to `normalize_peer_addr_middleware` on every
    /// request; `None` runs the middleware without a forward resolver.
    /// Missing extensions render as empty strings.
    fn forwarded_probe_router(resolver: Option<Arc<ForwardResolver>>) -> axum::Router {
        async fn probe(req: Request<Body>) -> String {
            let pa = req
                .extensions()
                .get::<PeerAddr>()
                .map(|p| p.addr.ip().to_string())
                .unwrap_or_default();
            let ci = req
                .extensions()
                .get::<ClientIp>()
                .map(|c| c.ip.to_string())
                .unwrap_or_default();
            format!("{pa}|{ci}")
        }
        axum::Router::new()
            .route("/probe", axum::routing::get(probe))
            .layer(axum::middleware::from_fn(move |req, next| {
                // Clone per call: the closure runs once per request and must
                // keep its captured resolver for later requests.
                let r = resolver.clone();
                normalize_peer_addr_middleware(r, req, next)
            }))
    }
4183
4184    fn probe_req(peer: &str, header: Option<(&str, &str)>) -> Request<Body> {
4185        let addr: SocketAddr = peer.parse().unwrap();
4186        let mut builder = Request::builder()
4187            .uri("/probe")
4188            .extension(ConnectInfo(addr));
4189        if let Some((name, value)) = header {
4190            builder = builder.header(name, value);
4191        }
4192        builder.body(Body::empty()).unwrap()
4193    }
4194
4195    #[tokio::test]
4196    async fn client_ip_equals_direct_without_resolver() {
4197        let app = forwarded_probe_router(None);
4198        let resp = app
4199            .oneshot(probe_req(
4200                "10.1.2.3:4444",
4201                Some(("x-forwarded-for", "203.0.113.7")),
4202            ))
4203            .await
4204            .unwrap();
4205        assert_eq!(
4206            body_string(resp).await,
4207            "10.1.2.3|10.1.2.3",
4208            "feature off: header ignored, ClientIp == direct"
4209        );
4210    }
4211
4212    #[tokio::test]
4213    async fn client_ip_resolved_for_trusted_peer() {
4214        let app = forwarded_probe_router(Some(forward_resolver(
4215            &["10.0.0.0/8"],
4216            ForwardedHeaderMode::XForwardedFor,
4217        )));
4218        let resp = app
4219            .oneshot(probe_req(
4220                "10.0.0.1:9999",
4221                Some(("x-forwarded-for", "203.0.113.7")),
4222            ))
4223            .await
4224            .unwrap();
4225        assert_eq!(
4226            body_string(resp).await,
4227            "10.0.0.1|203.0.113.7",
4228            "PeerAddr stays direct while ClientIp resolves"
4229        );
4230    }
4231
4232    #[tokio::test]
4233    async fn client_ip_falls_back_to_direct_on_malformed_header() {
4234        let app = forwarded_probe_router(Some(forward_resolver(
4235            &["10.0.0.0/8"],
4236            ForwardedHeaderMode::XForwardedFor,
4237        )));
4238        let resp = app
4239            .oneshot(probe_req(
4240                "10.0.0.1:9999",
4241                Some(("x-forwarded-for", "not-an-ip")),
4242            ))
4243            .await
4244            .unwrap();
4245        assert_eq!(
4246            body_string(resp).await,
4247            "10.0.0.1|10.0.0.1",
4248            "malformed chain falls back to the direct peer"
4249        );
4250    }
4251
4252    #[test]
4253    fn forwarded_header_mode_deserializes_kebab_case() {
4254        #[derive(serde::Deserialize)]
4255        struct Wrapper {
4256            mode: ForwardedHeaderMode,
4257        }
4258        let w: Wrapper = toml::from_str(r#"mode = "x-forwarded-for""#).unwrap();
4259        assert_eq!(w.mode, ForwardedHeaderMode::XForwardedFor);
4260        let w: Wrapper = toml::from_str(r#"mode = "forwarded""#).unwrap();
4261        assert_eq!(w.mode, ForwardedHeaderMode::Forwarded);
4262        assert!(
4263            toml::from_str::<Wrapper>(r#"mode = "XForwardedFor""#).is_err(),
4264            "PascalCase wire value must be rejected"
4265        );
4266    }
4267
4268    #[test]
4269    fn validate_rejects_bad_trusted_proxy_entry() {
4270        let cfg = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4271            .with_trusted_proxies(["not-a-cidr"]);
4272        let err = cfg.validate().expect_err("bad CIDR");
4273        assert!(err.to_string().contains("trusted_proxies"));
4274    }
4275
4276    #[test]
4277    fn validate_accepts_cidr_and_bare_ip_proxy_entries() {
4278        let cfg = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0").with_trusted_proxies([
4279            "10.0.0.0/8",
4280            "192.0.2.1",
4281            "2001:db8::1",
4282        ]);
4283        assert!(cfg.validate().is_ok(), "CIDRs and bare IPs are accepted");
4284    }
4285
4286    #[test]
4287    fn validate_rejects_forwarded_header_without_proxies() {
4288        let cfg = McpServerConfig::new("127.0.0.1:8080", "t", "1.0.0")
4289            .with_forwarded_header(ForwardedHeaderMode::Forwarded);
4290        let err = cfg.validate().expect_err("mode without proxies");
4291        assert!(err.to_string().contains("requires trusted_proxies"));
4292    }
4293
4294    // -- origin_check_middleware --
4295
    /// Build a test router with origin check middleware and a simple handler.
    ///
    /// `origins` becomes the shared allowlist passed to
    /// `origin_check_middleware`; `log_request_headers` is forwarded as-is.
    /// The only route is `GET /test`, which answers `"ok"`.
    fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
        let allowed: Arc<[String]> = Arc::from(origins);
        axum::Router::new()
            .route("/test", axum::routing::get(|| async { "ok" }))
            .layer(axum::middleware::from_fn(move |req, next| {
                // Clone the shared allowlist handle per request.
                let a = Arc::clone(&allowed);
                origin_check_middleware(a, log_request_headers, req, next)
            }))
    }
4306
4307    #[tokio::test]
4308    async fn origin_allowed_passes() {
4309        let app = origin_router(vec!["http://localhost:3000".into()], false);
4310        let req = Request::builder()
4311            .uri("/test")
4312            .header(header::ORIGIN, "http://localhost:3000")
4313            .body(Body::empty())
4314            .unwrap();
4315        let resp = app.oneshot(req).await.unwrap();
4316        assert_eq!(resp.status(), StatusCode::OK);
4317    }
4318
4319    #[tokio::test]
4320    async fn origin_rejected_returns_403() {
4321        let app = origin_router(vec!["http://localhost:3000".into()], false);
4322        let req = Request::builder()
4323            .uri("/test")
4324            .header(header::ORIGIN, "http://evil.com")
4325            .body(Body::empty())
4326            .unwrap();
4327        let resp = app.oneshot(req).await.unwrap();
4328        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
4329    }
4330
4331    #[tokio::test]
4332    async fn no_origin_header_passes() {
4333        let app = origin_router(vec!["http://localhost:3000".into()], false);
4334        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4335        let resp = app.oneshot(req).await.unwrap();
4336        assert_eq!(resp.status(), StatusCode::OK);
4337    }
4338
4339    #[tokio::test]
4340    async fn empty_allowlist_rejects_any_origin() {
4341        let app = origin_router(vec![], false);
4342        let req = Request::builder()
4343            .uri("/test")
4344            .header(header::ORIGIN, "http://anything.com")
4345            .body(Body::empty())
4346            .unwrap();
4347        let resp = app.oneshot(req).await.unwrap();
4348        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
4349    }
4350
4351    #[tokio::test]
4352    async fn empty_allowlist_passes_without_origin() {
4353        let app = origin_router(vec![], false);
4354        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4355        let resp = app.oneshot(req).await.unwrap();
4356        assert_eq!(resp.status(), StatusCode::OK);
4357    }
4358
4359    #[test]
4360    fn format_request_headers_redacts_sensitive_values() {
4361        let mut headers = axum::http::HeaderMap::new();
4362        headers.insert("authorization", "Bearer secret-token".parse().unwrap());
4363        headers.insert("cookie", "sid=abc".parse().unwrap());
4364        headers.insert("x-request-id", "req-123".parse().unwrap());
4365
4366        let out = format_request_headers_for_log(&headers);
4367        assert!(out.contains("authorization: [REDACTED]"));
4368        assert!(out.contains("cookie: [REDACTED]"));
4369        assert!(out.contains("x-request-id: req-123"));
4370        assert!(!out.contains("secret-token"));
4371    }
4372
4373    // -- security_headers_middleware --
4374
4375    fn security_router(is_tls: bool) -> axum::Router {
4376        security_router_with(is_tls, SecurityHeadersConfig::default())
4377    }
4378
    /// Build a `GET /test` router (answering `"ok"`) wrapped in
    /// `security_headers_middleware` with the given TLS flag and header
    /// configuration.
    fn security_router_with(is_tls: bool, cfg: SecurityHeadersConfig) -> axum::Router {
        let cfg = Arc::new(cfg);
        axum::Router::new()
            .route("/test", axum::routing::get(|| async { "ok" }))
            .layer(axum::middleware::from_fn(move |req, next| {
                // Clone the shared config handle per request.
                let c = Arc::clone(&cfg);
                security_headers_middleware(is_tls, c, req, next)
            }))
    }
4388
4389    #[tokio::test]
4390    async fn security_headers_set_on_response() {
4391        let app = security_router(false);
4392        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4393        let resp = app.oneshot(req).await.unwrap();
4394        assert_eq!(resp.status(), StatusCode::OK);
4395
4396        let h = resp.headers();
4397        assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
4398        assert_eq!(h.get("x-frame-options").unwrap(), "deny");
4399        assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
4400        assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
4401        assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
4402        assert_eq!(
4403            h.get("cross-origin-resource-policy").unwrap(),
4404            "same-origin"
4405        );
4406        assert_eq!(
4407            h.get("cross-origin-embedder-policy").unwrap(),
4408            "require-corp"
4409        );
4410        assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
4411        assert!(
4412            h.get("permissions-policy")
4413                .unwrap()
4414                .to_str()
4415                .unwrap()
4416                .contains("camera=()"),
4417            "permissions-policy must restrict browser features"
4418        );
4419        assert_eq!(
4420            h.get("content-security-policy").unwrap(),
4421            "default-src 'none'; frame-ancestors 'none'"
4422        );
4423        assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
4424        // No HSTS when TLS is off.
4425        assert!(h.get("strict-transport-security").is_none());
4426    }
4427
4428    #[tokio::test]
4429    async fn hsts_set_when_tls_enabled() {
4430        let app = security_router(true);
4431        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4432        let resp = app.oneshot(req).await.unwrap();
4433
4434        let hsts = resp.headers().get("strict-transport-security").unwrap();
4435        assert!(
4436            hsts.to_str().unwrap().contains("max-age=63072000"),
4437            "HSTS must set 2-year max-age"
4438        );
4439    }
4440
4441    // -- SecurityHeadersConfig validation + override semantics --
4442
4443    /// Build a minimal config with a custom SecurityHeadersConfig and
4444    /// drive it through `check()`. Returns the result so individual
4445    /// tests can assert on success or specific error messages.
4446    fn check_with_security_headers(headers: SecurityHeadersConfig) -> Result<(), McpxError> {
4447        let cfg =
4448            McpServerConfig::new("127.0.0.1:8080", "test", "0.0.0").with_security_headers(headers);
4449        cfg.check()
4450    }
4451
4452    #[test]
4453    fn security_headers_config_default_validates() {
4454        check_with_security_headers(SecurityHeadersConfig::default())
4455            .expect("default SecurityHeadersConfig must validate");
4456    }
4457
4458    #[test]
4459    fn security_headers_config_validate_accepts_empty_string() {
4460        // All twelve fields explicitly set to "" -> omit-everything mode.
4461        let h = SecurityHeadersConfig {
4462            x_content_type_options: Some(String::new()),
4463            x_frame_options: Some(String::new()),
4464            cache_control: Some(String::new()),
4465            referrer_policy: Some(String::new()),
4466            cross_origin_opener_policy: Some(String::new()),
4467            cross_origin_resource_policy: Some(String::new()),
4468            cross_origin_embedder_policy: Some(String::new()),
4469            permissions_policy: Some(String::new()),
4470            x_permitted_cross_domain_policies: Some(String::new()),
4471            content_security_policy: Some(String::new()),
4472            x_dns_prefetch_control: Some(String::new()),
4473            strict_transport_security: Some(String::new()),
4474        };
4475        check_with_security_headers(h).expect("Some(\"\") on every field must validate (omit-all)");
4476    }
4477
4478    #[test]
4479    fn security_headers_config_validate_rejects_bad_value() {
4480        // 0x07 (BEL) is not a valid HTTP header value char.
4481        let h = SecurityHeadersConfig {
4482            referrer_policy: Some("\u{0007}".into()),
4483            ..SecurityHeadersConfig::default()
4484        };
4485        let err = check_with_security_headers(h)
4486            .expect_err("control char in referrer_policy must reject");
4487        let msg = err.to_string();
4488        assert!(
4489            msg.contains("referrer_policy"),
4490            "error must name the offending field, got: {msg}"
4491        );
4492    }
4493
4494    #[test]
4495    fn security_headers_config_validate_rejects_hsts_preload() {
4496        let h = SecurityHeadersConfig {
4497            strict_transport_security: Some("max-age=63072000; includeSubDomains; preload".into()),
4498            ..SecurityHeadersConfig::default()
4499        };
4500        let err = check_with_security_headers(h).expect_err("HSTS with preload must reject");
4501        let msg = err.to_string();
4502        assert!(
4503            msg.contains("strict_transport_security"),
4504            "error must name the field, got: {msg}"
4505        );
4506        assert!(
4507            msg.to_lowercase().contains("preload"),
4508            "error must mention `preload`, got: {msg}"
4509        );
4510    }
4511
4512    #[test]
4513    fn security_headers_config_validate_rejects_hsts_preload_uppercase() {
4514        // Case-insensitive match.
4515        let h = SecurityHeadersConfig {
4516            strict_transport_security: Some("max-age=600; PRELOAD".into()),
4517            ..SecurityHeadersConfig::default()
4518        };
4519        check_with_security_headers(h).expect_err("HSTS preload check must be case-insensitive");
4520    }
4521
4522    #[tokio::test]
4523    async fn security_headers_override_honored() {
4524        // Override X-Frame-Options to SAMEORIGIN.
4525        let h = SecurityHeadersConfig {
4526            x_frame_options: Some("SAMEORIGIN".into()),
4527            ..SecurityHeadersConfig::default()
4528        };
4529        let app = security_router_with(false, h);
4530        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4531        let resp = app.oneshot(req).await.unwrap();
4532        assert_eq!(resp.status(), StatusCode::OK);
4533
4534        let xfo = resp.headers().get("x-frame-options").unwrap();
4535        assert_eq!(xfo, "SAMEORIGIN");
4536    }
4537
4538    #[tokio::test]
4539    async fn security_headers_empty_string_omits() {
4540        // Empty string on referrer-policy -> header absent.
4541        let h = SecurityHeadersConfig {
4542            referrer_policy: Some(String::new()),
4543            ..SecurityHeadersConfig::default()
4544        };
4545        let app = security_router_with(false, h);
4546        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4547        let resp = app.oneshot(req).await.unwrap();
4548        assert_eq!(resp.status(), StatusCode::OK);
4549
4550        assert!(
4551            resp.headers().get("referrer-policy").is_none(),
4552            "Some(\"\") must omit the header"
4553        );
4554        // Other defaults should still be present.
4555        assert_eq!(
4556            resp.headers().get("x-content-type-options").unwrap(),
4557            "nosniff"
4558        );
4559    }
4560
4561    #[tokio::test]
4562    async fn security_headers_hsts_only_when_tls() {
4563        // HSTS override is irrelevant when TLS is off.
4564        let h = SecurityHeadersConfig {
4565            strict_transport_security: Some("max-age=600".into()),
4566            ..SecurityHeadersConfig::default()
4567        };
4568        let app = security_router_with(false, h);
4569        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
4570        let resp = app.oneshot(req).await.unwrap();
4571        assert!(
4572            resp.headers().get("strict-transport-security").is_none(),
4573            "HSTS must remain absent on plaintext deployments even with override"
4574        );
4575    }
4576
4577    // -- oauth_token_cache_headers_middleware --
4578
4579    #[cfg(feature = "oauth")]
4580    #[tokio::test]
4581    async fn oauth_token_cache_headers_set_pragma_and_vary() {
4582        let app = axum::Router::new()
4583            .route("/token", axum::routing::post(|| async { "{}" }))
4584            .layer(axum::middleware::from_fn(
4585                oauth_token_cache_headers_middleware,
4586            ));
4587        let req = Request::builder()
4588            .method("POST")
4589            .uri("/token")
4590            .body(Body::from("{}"))
4591            .unwrap();
4592        let resp = app.oneshot(req).await.unwrap();
4593        assert_eq!(resp.status(), StatusCode::OK);
4594
4595        let h = resp.headers();
4596        assert_eq!(
4597            h.get("pragma").unwrap(),
4598            "no-cache",
4599            "RFC 6749 §5.1: token responses must set Pragma: no-cache"
4600        );
4601        let vary_values: Vec<String> = h
4602            .get_all("vary")
4603            .iter()
4604            .filter_map(|v| v.to_str().ok().map(str::to_owned))
4605            .collect();
4606        assert!(
4607            vary_values
4608                .iter()
4609                .any(|v| v.eq_ignore_ascii_case("Authorization")),
4610            "RFC 6750 §5.4: Vary must include Authorization, got {vary_values:?}"
4611        );
4612    }
4613
4614    #[cfg(feature = "oauth")]
4615    #[tokio::test]
4616    async fn oauth_token_cache_headers_preserve_existing_vary() {
4617        // Simulates a handler/layer that already set `Vary: Accept-Encoding`
4618        // (e.g. compression). Our middleware must APPEND, not REPLACE.
4619        let app = axum::Router::new()
4620            .route(
4621                "/token",
4622                axum::routing::post(|| async {
4623                    axum::response::Response::builder()
4624                        .header("vary", "Accept-Encoding")
4625                        .body(axum::body::Body::from("{}"))
4626                        .unwrap()
4627                }),
4628            )
4629            .layer(axum::middleware::from_fn(
4630                oauth_token_cache_headers_middleware,
4631            ));
4632        let req = Request::builder()
4633            .method("POST")
4634            .uri("/token")
4635            .body(Body::empty())
4636            .unwrap();
4637        let resp = app.oneshot(req).await.unwrap();
4638
4639        let vary: Vec<String> = resp
4640            .headers()
4641            .get_all("vary")
4642            .iter()
4643            .filter_map(|v| v.to_str().ok().map(str::to_owned))
4644            .collect();
4645        assert!(
4646            vary.iter().any(|v| v.contains("Accept-Encoding")),
4647            "must preserve pre-existing Vary value, got {vary:?}"
4648        );
4649        assert!(
4650            vary.iter().any(|v| v.contains("Authorization")),
4651            "must append Authorization to Vary, got {vary:?}"
4652        );
4653    }
4654
4655    // -- version endpoint --
4656
4657    #[test]
4658    fn version_payload_contains_expected_fields() {
4659        let v = version_payload("my-server", "1.2.3");
4660        assert_eq!(v["name"], "my-server");
4661        assert_eq!(v["version"], "1.2.3");
4662        assert!(v["build_git_sha"].is_string());
4663        assert!(v["build_timestamp"].is_string());
4664        assert!(v["rust_version"].is_string());
4665        assert!(v["mcpx_version"].is_string());
4666    }
4667
4668    // -- concurrency limit layer --
4669
4670    #[tokio::test]
4671    async fn concurrency_limit_layer_composes_and_serves() {
4672        // We only assert the layer stack compiles and a single request
4673        // below the cap still succeeds. True back-pressure behaviour
4674        // requires a live HTTP server and is covered by integration tests.
4675        let app = axum::Router::new()
4676            .route("/ok", axum::routing::get(|| async { "ok" }))
4677            .layer(
4678                tower::ServiceBuilder::new()
4679                    .layer(axum::error_handling::HandleErrorLayer::new(
4680                        |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
4681                    ))
4682                    .layer(tower::load_shed::LoadShedLayer::new())
4683                    .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
4684            );
4685        let resp = app
4686            .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
4687            .await
4688            .unwrap();
4689        assert_eq!(resp.status(), StatusCode::OK);
4690    }
4691
4692    // -- compression layer --
4693
4694    #[tokio::test]
4695    async fn compression_layer_gzip_encodes_response() {
4696        use tower_http::compression::Predicate as _;
4697
4698        let big_body = "a".repeat(4096);
4699        let app = axum::Router::new()
4700            .route(
4701                "/big",
4702                axum::routing::get(move || {
4703                    let body = big_body.clone();
4704                    async move { body }
4705                }),
4706            )
4707            .layer(
4708                tower_http::compression::CompressionLayer::new()
4709                    .gzip(true)
4710                    .br(true)
4711                    .compress_when(
4712                        tower_http::compression::DefaultPredicate::new()
4713                            .and(tower_http::compression::predicate::SizeAbove::new(1024)),
4714                    ),
4715            );
4716
4717        let req = Request::builder()
4718            .uri("/big")
4719            .header(header::ACCEPT_ENCODING, "gzip")
4720            .body(Body::empty())
4721            .unwrap();
4722        let resp = app.oneshot(req).await.unwrap();
4723        assert_eq!(resp.status(), StatusCode::OK);
4724        assert_eq!(
4725            resp.headers().get(header::CONTENT_ENCODING).unwrap(),
4726            "gzip"
4727        );
4728    }
4729
4730    // -- TlsListener handshake timeout --
4731
4732    #[tokio::test]
4733    async fn tls_handshake_timeout_reaps_idle_connections() {
4734        use tokio::io::AsyncReadExt as _;
4735
4736        let _ = rustls::crypto::ring::default_provider().install_default();
4737
4738        // Self-signed cert material on disk (TlsListener::new takes paths).
4739        let key = rcgen::KeyPair::generate().expect("generate key");
4740        let cert = rcgen::CertificateParams::new(vec!["localhost".to_owned()])
4741            .expect("cert params")
4742            .self_signed(&key)
4743            .expect("self-signed cert");
4744        let dir = std::env::temp_dir().join(format!(
4745            "rmcp-server-kit-hs-timeout-{}",
4746            std::time::SystemTime::now()
4747                .duration_since(std::time::UNIX_EPOCH)
4748                .expect("clock after epoch")
4749                .as_nanos()
4750        ));
4751        tokio::fs::create_dir_all(&dir).await.expect("temp dir");
4752        let cert_path = dir.join("server.crt");
4753        let key_path = dir.join("server.key");
4754        tokio::fs::write(&cert_path, cert.pem())
4755            .await
4756            .expect("write cert");
4757        tokio::fs::write(&key_path, key.serialize_pem())
4758            .await
4759            .expect("write key");
4760
4761        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
4762        let tls = TlsListener::new(
4763            listener,
4764            &cert_path,
4765            &key_path,
4766            None,
4767            None,
4768            Duration::from_millis(200),
4769            8, // custom concurrency cap: proves the plumbing end-to-end
4770        )
4771        .expect("tls listener");
4772        let addr = axum::serve::Listener::local_addr(&tls).expect("local addr");
4773
4774        // Connect and send NOTHING: the handshake worker must time out
4775        // after 200ms and drop the stream, which the client observes as
4776        // EOF or a reset well within the 2s deadline.
4777        let mut idle = tokio::net::TcpStream::connect(addr).await.expect("connect");
4778        let mut buf = [0_u8; 16];
4779        let read = tokio::time::timeout(Duration::from_secs(2), idle.read(&mut buf))
4780            .await
4781            .expect("server must reap the idle handshake within its timeout");
4782        match read {
4783            Ok(0) | Err(_) => {} // EOF or reset: connection was dropped.
4784            Ok(n) => panic!("unexpected {n} bytes from server during reaped handshake"),
4785        }
4786
4787        drop(tls);
4788    }
4789}