// rmcp_server_kit/transport.rs

1use std::{
2    future::Future,
3    net::{IpAddr, SocketAddr},
4    path::{Path, PathBuf},
5    pin::Pin,
6    sync::Arc,
7    time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{
12    body::Body,
13    extract::{ConnectInfo, Request},
14    middleware::Next,
15    response::IntoResponse,
16};
17use rmcp::{
18    ServerHandler,
19    transport::streamable_http_server::{
20        StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
21    },
22};
23use rustls::RootCertStore;
24use tokio::{
25    net::TcpListener,
26    sync::{Semaphore, mpsc},
27};
28use tokio_util::sync::CancellationToken;
29
30use crate::{
31    auth::{
32        AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
33        build_rate_limiter, extract_mtls_identity,
34    },
35    bounded_limiter::BoundedKeyedLimiter,
36    error::McpxError,
37    mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
38    rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
39};
40
41/// Map an internal `anyhow::Error` chain into a public [`McpxError::Startup`]
42/// at the public API boundary, flattening the chain via the alternate
43/// formatter so callers see the full causal path.
44#[allow(
45    clippy::needless_pass_by_value,
46    reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
47)]
48fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
49    McpxError::Startup(format!("{e:#}"))
50}
51
52/// Map a `std::io::Error` produced during server startup into a public
53/// [`McpxError::Startup`]. We deliberately do not use the [`McpxError::Io`]
54/// `From` impl here because startup-phase IO errors (bind, listener) are
55/// semantically distinct from request-time IO errors and should surface
56/// the originating operation in the message.
57#[allow(
58    clippy::needless_pass_by_value,
59    reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
60)]
61fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
62    McpxError::Startup(format!("{op}: {e}"))
63}
64
65/// Async readiness check callback for the `/readyz` endpoint.
66///
67/// Returns a JSON object with at least a `"ready"` boolean.
68/// When `ready` is false, the endpoint returns HTTP 503.
69pub type ReadinessCheck =
70    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
71
/// Remote socket address of the HTTP/TLS connection a request arrived on.
///
/// [`serve`] adds this as a request extension on every request, on both the
/// plain-HTTP and the TLS listener. Any axum handler can extract it. That
/// includes handlers on routes added through
/// [`McpServerConfig::with_extra_router`], which skip auth/RBAC and often need
/// the peer address for their own protection (for example, per-IP rate
/// limiting).
///
/// On the TLS listener, the same address is also stored as
/// [`axum::extract::ConnectInfo`]`<SocketAddr>`. Third-party middleware that
/// looks for the stock axum extension (such as per-IP rate-limit key
/// extractors) therefore keeps working under TLS.
///
/// # Semantics
///
/// - **Only the direct peer.** This is the socket's remote end. Behind an
///   L4/L7 proxy or load balancer it is the proxy. The framework does
///   **no** `X-Forwarded-For` / `Forwarded` processing.
/// - **Present on both HTTP and TLS** transports ([`serve`]).
/// - **Missing under [`serve_stdio`]**, because a stdio session has no network
///   peer (stdio does not go through the HTTP stack).
/// - The Prometheus metrics listener (feature `metrics`) uses a separate
///   router and does not attach this extension.
///
/// # Privacy
///
/// `PeerAddr` carries raw network metadata about the client. The framework
/// never logs it on its own. Whether to log or persist peer addresses is left
/// to application policy.
///
/// # Example
///
/// ```no_run
/// use axum::{Router, routing::get};
/// use rmcp_server_kit::transport::{McpServerConfig, PeerAddr};
///
/// async fn whoami(peer: PeerAddr) -> String {
///     peer.addr.ip().to_string()
/// }
///
/// let _config = McpServerConfig::new("127.0.0.1:8443", "my-server", "1.0.0")
///     .with_extra_router(Router::new().route("/whoami", get(whoami)));
/// ```
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct PeerAddr {
    /// Remote socket address of this connection, as seen by the listener.
    pub addr: SocketAddr,
}

impl PeerAddr {
    /// Wrap a socket address as a [`PeerAddr`].
    ///
    /// Crate-internal only. Downstream code gets `PeerAddr` from request
    /// extensions and never builds one itself.
    #[must_use]
    pub(crate) const fn new(socket: SocketAddr) -> Self {
        Self { addr: socket }
    }
}
131
132/// Extract [`PeerAddr`] from request extensions.
133///
134/// # Rejection
135///
136/// Responds `500 Internal Server Error` when the extension is missing.
137/// A missing `PeerAddr` means the handler is not running under [`serve`]
138/// (e.g. the router was mounted on a hand-rolled listener) — a wiring
139/// bug, not a client error.
140impl<S: Send + Sync> axum::extract::FromRequestParts<S> for PeerAddr {
141    type Rejection = (axum::http::StatusCode, &'static str);
142
143    async fn from_request_parts(
144        parts: &mut axum::http::request::Parts,
145        _state: &S,
146    ) -> Result<Self, Self::Rejection> {
147        parts.extensions.get::<Self>().copied().ok_or((
148            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
149            "peer address unavailable: not running under rmcp-server-kit serve()",
150        ))
151    }
152}
153
/// Per-header overrides for the OWASP security headers that the global
/// response middleware adds to every response.
///
/// Every field has three possible states:
///
/// | Field value   | Effect                                                   |
/// |---------------|----------------------------------------------------------|
/// | `None`        | Keep the built-in default listed on the field.           |
/// | `Some("")`    | **Suppress** the header completely.                      |
/// | `Some(value)` | Send `header: value`; checked at config-load time.       |
///
/// Non-empty values are run through [`axum::http::HeaderValue::from_str`] by
/// [`McpServerConfig::validate`]. A malformed value is therefore reported
/// before the server starts accepting traffic.
///
/// `Strict-Transport-Security` gets one extra check: any value containing
/// `preload` (case-insensitive) is rejected. Joining the HSTS preload list is
/// reserved for a future, explicit builder method rather than this generic
/// override.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct SecurityHeadersConfig {
    /// `X-Content-Type-Options` override; built-in value `nosniff`.
    pub x_content_type_options: Option<String>,
    /// `X-Frame-Options` override; built-in value `deny`.
    pub x_frame_options: Option<String>,
    /// `Cache-Control` override; built-in value `no-store, max-age=0`.
    pub cache_control: Option<String>,
    /// `Referrer-Policy` override; built-in value `no-referrer`.
    pub referrer_policy: Option<String>,
    /// `Cross-Origin-Opener-Policy` override; built-in value `same-origin`.
    pub cross_origin_opener_policy: Option<String>,
    /// `Cross-Origin-Resource-Policy` override; built-in value `same-origin`.
    pub cross_origin_resource_policy: Option<String>,
    /// `Cross-Origin-Embedder-Policy` override; built-in value `require-corp`.
    pub cross_origin_embedder_policy: Option<String>,
    /// `Permissions-Policy` override; built-in value
    /// `accelerometer=(), camera=(), geolocation=(), microphone=()`.
    pub permissions_policy: Option<String>,
    /// `X-Permitted-Cross-Domain-Policies` override; built-in value `none`.
    pub x_permitted_cross_domain_policies: Option<String>,
    /// `Content-Security-Policy` override; built-in value
    /// `default-src 'none'; frame-ancestors 'none'`.
    pub content_security_policy: Option<String>,
    /// `X-DNS-Prefetch-Control` override; built-in value `off`.
    pub x_dns_prefetch_control: Option<String>,
    /// `Strict-Transport-Security` override; built-in value (TLS only)
    /// `max-age=63072000; includeSubDomains`. The header is only sent when
    /// TLS is active, so on plaintext deployments this override has no
    /// effect. The validator rejects any value containing `preload` (any
    /// case).
    pub strict_transport_security: Option<String>,
}
207
208/// Configuration for the MCP server.
209#[allow(
210    missing_debug_implementations,
211    reason = "contains callback/trait objects that don't impl Debug"
212)]
213#[allow(
214    clippy::struct_excessive_bools,
215    reason = "server configuration naturally has many boolean feature flags"
216)]
217#[non_exhaustive]
218pub struct McpServerConfig {
219    /// Socket address the MCP HTTP server binds to.
220    #[deprecated(
221        since = "0.13.0",
222        note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in a future major release"
223    )]
224    pub bind_addr: String,
225    /// Server name advertised via MCP `initialize`.
226    #[deprecated(
227        since = "0.13.0",
228        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
229    )]
230    pub name: String,
231    /// Server version advertised via MCP `initialize`.
232    #[deprecated(
233        since = "0.13.0",
234        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
235    )]
236    pub version: String,
237    /// Path to the TLS certificate (PEM). Required for TLS/mTLS.
238    #[deprecated(
239        since = "0.13.0",
240        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
241    )]
242    pub tls_cert_path: Option<PathBuf>,
243    /// Path to the TLS private key (PEM). Required for TLS/mTLS.
244    #[deprecated(
245        since = "0.13.0",
246        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
247    )]
248    pub tls_key_path: Option<PathBuf>,
249    /// Optional authentication config. When `Some` and `enabled`, auth
250    /// is enforced on `/mcp`. `/healthz` is always open.
251    #[deprecated(
252        since = "0.13.0",
253        note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in a future major release"
254    )]
255    pub auth: Option<AuthConfig>,
256    /// Optional RBAC policy. When present and enabled, tool calls are
257    /// checked against the policy after authentication.
258    #[deprecated(
259        since = "0.13.0",
260        note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in a future major release"
261    )]
262    pub rbac: Option<Arc<RbacPolicy>>,
263    /// Allowed Origin values for DNS rebinding protection (MCP spec MUST).
264    /// When empty and `public_url` is set, the origin is auto-derived from
265    /// the public URL. When both are empty, only requests with no Origin
266    /// header are accepted.
267    /// Example entries: `"http://localhost:3000"`, `"https://myapp.example.com"`.
268    #[deprecated(
269        since = "0.13.0",
270        note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in a future major release"
271    )]
272    pub allowed_origins: Vec<String>,
273    /// Maximum tool invocations per source IP per minute.
274    /// When set, enforced on every `tools/call` request.
275    #[deprecated(
276        since = "0.13.0",
277        note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in a future major release"
278    )]
279    pub tool_rate_limit: Option<u32>,
280    /// Maximum requests per source IP per minute for routes merged via
281    /// [`with_extra_router`](Self::with_extra_router). Opt-in: `None`
282    /// (the default) installs no limiter. Startup-only (not
283    /// hot-reloadable via [`ReloadHandle`]).
284    ///
285    /// Keyed by the **direct socket peer** ([`PeerAddr`] semantics — no
286    /// `X-Forwarded-For` interpretation): behind a reverse proxy all
287    /// clients share the proxy's bucket, and IPv6 single-host address
288    /// rotation can evade per-IP keying. Treat this as an abuse speed
289    /// bump for unauthenticated application endpoints, not tenant
290    /// isolation. On limit: HTTP 429 with a plain-text body, matching
291    /// the tool/auth limiters.
292    #[deprecated(
293        since = "1.11.0",
294        note = "use McpServerConfig::with_extra_route_rate_limit(); direct field access will become pub(crate) in a future major release"
295    )]
296    pub extra_route_rate_limit: Option<u32>,
297    /// Optional readiness probe for `/readyz`.
298    /// When `None`, `/readyz` mirrors `/healthz` (always OK).
299    #[deprecated(
300        since = "0.13.0",
301        note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in a future major release"
302    )]
303    pub readiness_check: Option<ReadinessCheck>,
304    /// Maximum request body size in bytes. Default: 1 MiB.
305    /// Protects against oversized payloads causing OOM.
306    #[deprecated(
307        since = "0.13.0",
308        note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in a future major release"
309    )]
310    pub max_request_body: usize,
311    /// Request processing timeout. Default: 120s.
312    /// Requests exceeding this duration receive 408 Request Timeout.
313    #[deprecated(
314        since = "0.13.0",
315        note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in a future major release"
316    )]
317    pub request_timeout: Duration,
318    /// Graceful shutdown timeout. Default: 30s.
319    /// After the shutdown signal, in-flight requests have this long to finish.
320    #[deprecated(
321        since = "0.13.0",
322        note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in a future major release"
323    )]
324    pub shutdown_timeout: Duration,
325    /// Idle timeout for MCP sessions. Sessions with no activity for this
326    /// duration are closed automatically. Default: 20 minutes.
327    #[deprecated(
328        since = "0.13.0",
329        note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in a future major release"
330    )]
331    pub session_idle_timeout: Duration,
332    /// Interval for SSE keep-alive pings. Prevents proxies and load
333    /// balancers from killing idle connections. Default: 15 seconds.
334    #[deprecated(
335        since = "0.13.0",
336        note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in a future major release"
337    )]
338    pub sse_keep_alive: Duration,
339    /// Callback invoked once the server is built, delivering a
340    /// [`ReloadHandle`] for hot-reloading auth keys and RBAC policy
341    /// at runtime (e.g. on SIGHUP). Only useful when auth/RBAC is enabled.
342    #[deprecated(
343        since = "0.13.0",
344        note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in a future major release"
345    )]
346    pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
347    /// Additional application-specific routes merged into the top-level
348    /// router.  These routes **bypass** the MCP auth and RBAC middleware,
349    /// so the application is responsible for its own auth on them.
350    /// Handlers can extract [`PeerAddr`] (or
351    /// [`axum::extract::ConnectInfo`]`<SocketAddr>` for third-party
352    /// middleware compatibility) regardless of whether TLS is enabled.
353    #[deprecated(
354        since = "0.13.0",
355        note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in a future major release"
356    )]
357    pub extra_router: Option<axum::Router>,
358    /// Externally reachable base URL (e.g. `https://mcp.example.com`).
359    /// When set, OAuth metadata endpoints advertise this URL instead of
360    /// the listen address. Required when binding `0.0.0.0` behind a
361    /// reverse proxy or inside a container.
362    #[deprecated(
363        since = "0.13.0",
364        note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in a future major release"
365    )]
366    pub public_url: Option<String>,
367    /// Log inbound HTTP request headers at DEBUG level.
368    /// Sensitive values remain redacted.
369    #[deprecated(
370        since = "0.13.0",
371        note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in a future major release"
372    )]
373    pub log_request_headers: bool,
374    /// Enable gzip/br response compression on MCP responses.
375    /// Defaults to `false` to preserve existing behaviour.
376    #[deprecated(
377        since = "0.13.0",
378        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
379    )]
380    pub compression_enabled: bool,
381    /// Minimum response body size (in bytes) before compression kicks in.
382    /// Only used when `compression_enabled` is true. Default: 1024.
383    #[deprecated(
384        since = "0.13.0",
385        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
386    )]
387    pub compression_min_size: u16,
388    /// Global cap on in-flight HTTP requests across the whole server.
389    /// When `Some`, requests over the cap receive 503 Service Unavailable
390    /// via `tower::load_shed`. Default: `None` (unlimited).
391    #[deprecated(
392        since = "0.13.0",
393        note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in a future major release"
394    )]
395    pub max_concurrent_requests: Option<usize>,
396    /// Enable `/admin/*` diagnostic endpoints. Requires `auth` to be
397    /// configured and `enabled`. Default: `false`.
398    #[deprecated(
399        since = "0.13.0",
400        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
401    )]
402    pub admin_enabled: bool,
403    /// RBAC role required to access admin endpoints. Default: `"admin"`.
404    #[deprecated(
405        since = "0.13.0",
406        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
407    )]
408    pub admin_role: String,
409    /// Enable Prometheus metrics endpoint on a separate listener.
410    /// Requires the `metrics` crate feature.
411    #[cfg(feature = "metrics")]
412    #[deprecated(
413        since = "0.13.0",
414        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
415    )]
416    pub metrics_enabled: bool,
417    /// Bind address for the Prometheus metrics listener. Default: `127.0.0.1:9090`.
418    #[cfg(feature = "metrics")]
419    #[deprecated(
420        since = "0.13.0",
421        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
422    )]
423    pub metrics_bind: String,
424    /// Per-header overrides for the OWASP security headers emitted by
425    /// the global response middleware. See [`SecurityHeadersConfig`]
426    /// for the three-state semantic and validation rules.
427    #[deprecated(
428        since = "1.5.0",
429        note = "use McpServerConfig::with_security_headers(); direct field access will become pub(crate) in a future major release"
430    )]
431    pub security_headers: SecurityHeadersConfig,
432    /// Per-handshake deadline on the TLS accept path. Idle or slow-loris
433    /// connections are dropped once it elapses. Default: 10s.
434    ///
435    /// Startup-only: bound at listener construction, NOT hot-reloadable
436    /// via [`ReloadHandle`]. Ignored unless TLS is configured.
437    #[deprecated(
438        since = "1.9.0",
439        note = "use McpServerConfig::with_tls_handshake_timeout(); direct field access will become pub(crate) in a future major release"
440    )]
441    pub tls_handshake_timeout: Duration,
442    /// Cap on concurrently in-flight TLS handshakes. At saturation the
443    /// acceptor stops pulling new connections from the kernel backlog
444    /// (backpressure) instead of accepting and dropping. Default: 256.
445    ///
446    /// Startup-only: bound at listener construction, NOT hot-reloadable
447    /// via [`ReloadHandle`]. Ignored unless TLS is configured.
448    #[deprecated(
449        since = "1.9.0",
450        note = "use McpServerConfig::with_max_concurrent_tls_handshakes(); direct field access will become pub(crate) in a future major release"
451    )]
452    pub max_concurrent_tls_handshakes: usize,
453}
454
/// Proof-carrying wrapper: holding a `Validated<T>` means the wrapped value
/// has already passed its validation step.
///
/// [`McpServerConfig::validate`] is the only producer of
/// `Validated<McpServerConfig>`. Both [`serve`] and [`serve_with_listener`]
/// require one, so the type system guarantees a server never starts from an
/// unchecked config. The tuple field is private, so code outside this module
/// cannot build the wrapper by hand and skip validation.
///
/// Read the contents through [`Validated::as_inner`]. To change anything,
/// take the value back out with [`Validated::into_inner`], modify it, and
/// validate it again.
///
/// # Example
///
/// ```no_run
/// use rmcp_server_kit::transport::{McpServerConfig, Validated, serve};
/// use rmcp::handler::server::ServerHandler;
/// use rmcp::model::{ServerCapabilities, ServerInfo};
///
/// #[derive(Clone)]
/// struct H;
/// impl ServerHandler for H {
///     fn get_info(&self) -> ServerInfo {
///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
///     }
/// }
///
/// # async fn example() -> rmcp_server_kit::Result<()> {
/// let config: Validated<McpServerConfig> =
///     McpServerConfig::new("127.0.0.1:8080", "my-server", "0.1.0").validate()?;
/// serve(config, || H).await
/// # }
/// ```
///
/// Leaving out `.validate()?` is rejected by the compiler:
///
/// ```compile_fail
/// use rmcp_server_kit::transport::{McpServerConfig, serve};
/// use rmcp::handler::server::ServerHandler;
/// use rmcp::model::{ServerCapabilities, ServerInfo};
///
/// #[derive(Clone)]
/// struct H;
/// impl ServerHandler for H {
///     fn get_info(&self) -> ServerInfo {
///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
///     }
/// }
///
/// # async fn example() -> rmcp_server_kit::Result<()> {
/// let config = McpServerConfig::new("127.0.0.1:8080", "my-server", "0.1.0");
/// // Missing `.validate()?` -> mismatched types: expected
/// // `Validated<McpServerConfig>`, found `McpServerConfig`.
/// serve(config, || H).await
/// # }
/// ```
#[allow(
    missing_debug_implementations,
    reason = "wraps T which may not implement Debug; manual impl below avoids leaking inner contents into logs"
)]
pub struct Validated<T>(T);

/// Opaque `Debug` output: always `Validated { .. }`.
///
/// The wrapped value is never printed, so its contents cannot reach logs
/// through `{:?}`. No `T: Debug` bound is needed.
impl<T> std::fmt::Debug for Validated<T> {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("Validated");
        builder.finish_non_exhaustive()
    }
}

impl<T> Validated<T> {
    /// Shared, read-only access to the validated value.
    #[must_use]
    pub fn as_inner(&self) -> &T {
        let Self(inner) = self;
        inner
    }

    /// Unwrap the value and give up the validation guarantee.
    ///
    /// The returned value must be validated again before it is passed to
    /// [`serve`] or [`serve_with_listener`].
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}
540
541#[allow(
542    deprecated,
543    reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
544)]
545impl McpServerConfig {
546    /// Create a new server configuration with the given bind address,
547    /// server name, and version. All other fields use safe defaults.
548    ///
549    /// Use the chainable `with_*` / `enable_*` builder methods to
550    /// customize. Call [`McpServerConfig::validate`] to obtain a
551    /// [`Validated<McpServerConfig>`] proof token, which is required by
552    /// [`serve`] and [`serve_with_listener`].
553    #[must_use]
554    pub fn new(
555        bind_addr: impl Into<String>,
556        name: impl Into<String>,
557        version: impl Into<String>,
558    ) -> Self {
559        Self {
560            bind_addr: bind_addr.into(),
561            name: name.into(),
562            version: version.into(),
563            tls_cert_path: None,
564            tls_key_path: None,
565            auth: None,
566            rbac: None,
567            allowed_origins: Vec::new(),
568            tool_rate_limit: None,
569            readiness_check: None,
570            max_request_body: 1024 * 1024,
571            request_timeout: Duration::from_mins(2),
572            shutdown_timeout: Duration::from_secs(30),
573            session_idle_timeout: Duration::from_mins(20),
574            sse_keep_alive: Duration::from_secs(15),
575            on_reload_ready: None,
576            extra_router: None,
577            public_url: None,
578            log_request_headers: false,
579            compression_enabled: false,
580            compression_min_size: 1024,
581            max_concurrent_requests: None,
582            admin_enabled: false,
583            admin_role: "admin".to_owned(),
584            #[cfg(feature = "metrics")]
585            metrics_enabled: false,
586            #[cfg(feature = "metrics")]
587            metrics_bind: "127.0.0.1:9090".into(),
588            security_headers: SecurityHeadersConfig::default(),
589            tls_handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
590            max_concurrent_tls_handshakes: DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES,
591            extra_route_rate_limit: None,
592        }
593    }
594
595    // ---------------------------------------------------------------
596    // Builder methods (fluent, consume + return self).
597    //
598    // Each method is `#[must_use]` because dropping the returned
599    // `McpServerConfig` discards the configuration change.
600    // ---------------------------------------------------------------
601
602    /// Attach an authentication configuration. Required for
603    /// [`enable_admin`](Self::enable_admin) and any non-public deployment.
604    #[must_use]
605    pub fn with_auth(mut self, auth: AuthConfig) -> Self {
606        self.auth = Some(auth);
607        self
608    }
609
610    /// Override one or more of the OWASP security headers emitted on
611    /// every response. See [`SecurityHeadersConfig`] for the three-state
612    /// semantic (`None` = default, `Some("")` = omit, `Some(v)` =
613    /// override). Values are validated by [`Self::validate`].
614    #[must_use]
615    pub fn with_security_headers(mut self, headers: SecurityHeadersConfig) -> Self {
616        self.security_headers = headers;
617        self
618    }
619
620    /// Override the bind address (e.g. `127.0.0.1:8080`). Useful when the
621    /// final port is only known after pre-binding an ephemeral listener
622    /// (tests, dynamic-port deployments).
623    #[must_use]
624    pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
625        self.bind_addr = addr.into();
626        self
627    }
628
629    /// Attach an RBAC policy. Tool calls are checked against the policy
630    /// after authentication.
631    #[must_use]
632    pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
633        self.rbac = Some(rbac);
634        self
635    }
636
637    /// Configure TLS by providing the certificate and private key paths
638    /// (PEM). Both must be readable at startup. Without this call, the
639    /// server runs plain HTTP.
640    #[must_use]
641    pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
642        self.tls_cert_path = Some(cert_path.into());
643        self.tls_key_path = Some(key_path.into());
644        self
645    }
646
647    /// Set the externally reachable base URL (e.g. `https://mcp.example.com`).
648    /// Required when binding `0.0.0.0` behind a reverse proxy or inside
649    /// a container so OAuth metadata and auto-derived origins resolve correctly.
650    #[must_use]
651    pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
652        self.public_url = Some(url.into());
653        self
654    }
655
656    /// Replace the allowed Origin allow-list (DNS-rebinding protection).
657    /// When empty and [`with_public_url`](Self::with_public_url) is set,
658    /// the origin is auto-derived.
659    #[must_use]
660    pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
661    where
662        I: IntoIterator<Item = S>,
663        S: Into<String>,
664    {
665        self.allowed_origins = origins.into_iter().map(Into::into).collect();
666        self
667    }
668
669    /// Merge an additional axum router at the top level. Routes added
670    /// here **bypass** rmcp-server-kit auth and RBAC; the application is responsible
671    /// for its own protection.
672    ///
673    /// To support that protection (e.g. per-IP rate limiting on
674    /// unauthenticated endpoints), every request served by [`serve`]
675    /// carries the client peer address regardless of whether TLS is
676    /// enabled: extract the framework-owned [`PeerAddr`] in your
677    /// handlers, or rely on [`axum::extract::ConnectInfo`]`<SocketAddr>`
678    /// for stock third-party middleware (e.g. per-IP rate-limit key
679    /// extractors). Neither extension exists under [`serve_stdio`],
680    /// which has no network peer.
681    #[must_use]
682    pub fn with_extra_router(mut self, router: axum::Router) -> Self {
683        self.extra_router = Some(router);
684        self
685    }
686
687    /// Install an async readiness probe for `/readyz`. Without this call,
688    /// `/readyz` mirrors `/healthz` (always 200 OK).
689    #[must_use]
690    pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
691        self.readiness_check = Some(check);
692        self
693    }
694
695    /// Override the maximum request body (bytes). Must be `> 0`.
696    /// Default: 1 MiB.
697    #[must_use]
698    pub fn with_max_request_body(mut self, bytes: usize) -> Self {
699        self.max_request_body = bytes;
700        self
701    }
702
703    /// Override the per-request processing timeout. Default: 2 minutes.
704    #[must_use]
705    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
706        self.request_timeout = timeout;
707        self
708    }
709
710    /// Override the graceful shutdown grace period. Default: 30 seconds.
711    #[must_use]
712    pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
713        self.shutdown_timeout = timeout;
714        self
715    }
716
717    /// Override the MCP session idle timeout. Default: 20 minutes.
718    #[must_use]
719    pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
720        self.session_idle_timeout = timeout;
721        self
722    }
723
724    /// Override the SSE keep-alive interval. Default: 15 seconds.
725    #[must_use]
726    pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
727        self.sse_keep_alive = interval;
728        self
729    }
730
731    /// Cap the global number of in-flight HTTP requests via
732    /// `tower::load_shed`. Excess requests receive 503 Service Unavailable.
733    /// Default: unlimited.
734    #[must_use]
735    pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
736        self.max_concurrent_requests = Some(limit);
737        self
738    }
739
740    /// Override the per-handshake deadline on the TLS accept path.
741    /// Idle or slow-loris connections are dropped once it elapses.
742    /// Default: 10s. Must be greater than zero.
743    ///
744    /// Startup-only: the value is bound at listener construction and is
745    /// NOT hot-reloadable via [`ReloadHandle`]. Has no effect unless TLS
746    /// is configured via [`Self::with_tls`].
747    #[must_use]
748    pub fn with_tls_handshake_timeout(mut self, timeout: Duration) -> Self {
749        self.tls_handshake_timeout = timeout;
750        self
751    }
752
753    /// Override the cap on concurrently in-flight TLS handshakes. At
754    /// saturation the acceptor stops pulling new connections from the
755    /// kernel backlog (backpressure) instead of accepting and dropping.
756    /// Default: 256. Must be greater than zero.
757    ///
758    /// Startup-only: the value is bound at listener construction and is
759    /// NOT hot-reloadable via [`ReloadHandle`]. Has no effect unless TLS
760    /// is configured via [`Self::with_tls`].
761    #[must_use]
762    pub fn with_max_concurrent_tls_handshakes(mut self, limit: usize) -> Self {
763        self.max_concurrent_tls_handshakes = limit;
764        self
765    }
766
767    /// Cap tool invocations per source IP per minute. Enforced on every
768    /// `tools/call` request.
769    #[must_use]
770    pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
771        self.tool_rate_limit = Some(per_minute);
772        self
773    }
774
775    /// Cap requests per source IP per minute on routes merged via
776    /// [`with_extra_router`](Self::with_extra_router) — the natural
777    /// protection for unauthenticated application endpoints (OAuth
778    /// callbacks, registration, …) that bypass auth/RBAC by design.
779    ///
780    /// Must be greater than zero (validated by
781    /// [`validate`](Self::validate)). Startup-only. See the
782    /// `extra_route_rate_limit` field docs for keying semantics and
783    /// caveats (direct peer only, IPv6 rotation, proxy collapse,
784    /// bounded-memory shared-fate under key spray).
785    #[must_use]
786    pub fn with_extra_route_rate_limit(mut self, per_minute: u32) -> Self {
787        self.extra_route_rate_limit = Some(per_minute);
788        self
789    }
790
791    /// Register a callback that receives the [`ReloadHandle`] after the
792    /// server is built. Use it to wire SIGHUP-style hot reloads of API
793    /// keys and RBAC policy.
794    #[must_use]
795    pub fn with_reload_callback<F>(mut self, callback: F) -> Self
796    where
797        F: FnOnce(ReloadHandle) + Send + 'static,
798    {
799        self.on_reload_ready = Some(Box::new(callback));
800        self
801    }
802
803    /// Enable gzip/brotli response compression on MCP responses.
804    /// `min_size` is the smallest body size (bytes) eligible for
805    /// compression. Default min size: 1024.
806    #[must_use]
807    pub fn enable_compression(mut self, min_size: u16) -> Self {
808        self.compression_enabled = true;
809        self.compression_min_size = min_size;
810        self
811    }
812
813    /// Enable `/admin/*` diagnostic endpoints. Requires
814    /// [`with_auth`](Self::with_auth) to be set and enabled; otherwise
815    /// [`validate`](Self::validate) returns an error. `role` is the RBAC
816    /// role gate (default: `"admin"`).
817    #[must_use]
818    pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
819        self.admin_enabled = true;
820        self.admin_role = role.into();
821        self
822    }
823
824    /// Log inbound HTTP request headers at DEBUG level. Sensitive
825    /// values remain redacted by the logging layer.
826    #[must_use]
827    pub fn enable_request_header_logging(mut self) -> Self {
828        self.log_request_headers = true;
829        self
830    }
831
832    /// Enable the Prometheus metrics listener on `bind` (e.g.
833    /// `127.0.0.1:9090`). Requires the `metrics` crate feature.
834    #[cfg(feature = "metrics")]
835    #[must_use]
836    pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
837        self.metrics_enabled = true;
838        self.metrics_bind = bind.into();
839        self
840    }
841
842    /// Validate the configuration and consume `self`, returning a
843    /// [`Validated<McpServerConfig>`] proof token required by [`serve`]
844    /// and [`serve_with_listener`]. This is the only way to construct
845    /// `Validated<McpServerConfig>`, so the type system guarantees
846    /// validation has run before the server starts.
847    ///
848    /// Checks:
849    ///
850    /// 1. `admin_enabled` requires `auth` to be configured and enabled.
851    /// 2. `tls_cert_path` and `tls_key_path` must both be set or both
852    ///    be unset.
853    /// 3. `bind_addr` must parse as a [`SocketAddr`].
854    /// 4. `public_url`, when set, must start with `http://` or `https://`.
855    /// 5. Each entry in `allowed_origins` must start with `http://` or
856    ///    `https://`.
857    /// 6. `max_request_body` must be greater than zero.
858    /// 7. When the `oauth` feature is enabled and an [`OAuthConfig`] is
859    ///    present, all OAuth URL fields (`jwks_uri`, `proxy.authorize_url`,
860    ///    `proxy.token_url`, `proxy.introspection_url`,
861    ///    `proxy.revocation_url`, `token_exchange.token_url`) must parse
862    ///    and use the `https` scheme. Set
863    ///    [`OAuthConfig::allow_http_oauth_urls`] to permit `http://`
864    ///    targets (strongly discouraged in production - see the field-level
865    ///    docs for the threat model).
866    ///
867    /// [`OAuthConfig`]: crate::oauth::OAuthConfig
868    /// [`OAuthConfig::allow_http_oauth_urls`]: crate::oauth::OAuthConfig::allow_http_oauth_urls
869    ///
870    /// # Errors
871    ///
872    /// Returns [`McpxError::Config`] with a human-readable message on
873    /// the first validation failure.
874    pub fn validate(self) -> Result<Validated<Self>, McpxError> {
875        self.check()?;
876        Ok(Validated(self))
877    }
878
879    /// Run the validation checks without consuming `self`. Used by
880    /// internal call sites (e.g. tests) that need to inspect a config
881    /// without taking ownership.
882    fn check(&self) -> Result<(), McpxError> {
883        // 1. admin <-> auth dependency. Mirrors the runtime check in
884        //    `build_app_router`: admin endpoints require an auth state,
885        //    which is built only when `auth` is `Some` *and* `enabled`.
886        if self.admin_enabled {
887            let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
888            if !auth_enabled {
889                return Err(McpxError::Config(
890                    "admin_enabled=true requires auth to be configured and enabled".into(),
891                ));
892            }
893        }
894
895        // 2. TLS cert / key must be paired
896        match (&self.tls_cert_path, &self.tls_key_path) {
897            (Some(_), None) => {
898                return Err(McpxError::Config(
899                    "tls_cert_path is set but tls_key_path is missing".into(),
900                ));
901            }
902            (None, Some(_)) => {
903                return Err(McpxError::Config(
904                    "tls_key_path is set but tls_cert_path is missing".into(),
905                ));
906            }
907            _ => {}
908        }
909
910        // 3. bind_addr parses
911        if self.bind_addr.parse::<SocketAddr>().is_err() {
912            return Err(McpxError::Config(format!(
913                "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
914                self.bind_addr
915            )));
916        }
917
918        // 4. public_url scheme
919        if let Some(ref url) = self.public_url
920            && !(url.starts_with("http://") || url.starts_with("https://"))
921        {
922            return Err(McpxError::Config(format!(
923                "public_url {url:?} must start with http:// or https://"
924            )));
925        }
926
927        // 5. allowed_origins scheme
928        for origin in &self.allowed_origins {
929            if !(origin.starts_with("http://") || origin.starts_with("https://")) {
930                return Err(McpxError::Config(format!(
931                    "allowed_origins entry {origin:?} must start with http:// or https://"
932                )));
933            }
934        }
935
936        // 6. max_request_body > 0
937        if self.max_request_body == 0 {
938            return Err(McpxError::Config(
939                "max_request_body must be greater than zero".into(),
940            ));
941        }
942
943        // 6b. extra_route_rate_limit, when set, must be > 0. Unlike the
944        // legacy tool_rate_limit (which clamps 0 to its default at
945        // construction), new knobs fail fast on nonsensical values.
946        if self.extra_route_rate_limit == Some(0) {
947            return Err(McpxError::Config(
948                "extra_route_rate_limit must be greater than zero".into(),
949            ));
950        }
951
952        // 7. OAuth URL fields enforce HTTPS (unless `allow_http_oauth_urls`)
953        #[cfg(feature = "oauth")]
954        if let Some(auth_cfg) = &self.auth
955            && let Some(oauth_cfg) = &auth_cfg.oauth
956        {
957            oauth_cfg.validate()?;
958        }
959
960        // 8. Security-header overrides parse as valid HTTP header values,
961        //    and HSTS does not smuggle in a `preload` directive.
962        validate_security_headers(&self.security_headers)?;
963
964        // 9. max_concurrent_requests must be > 0 when set. Zero would
965        //    deadlock the global concurrency limiter and reject every
966        //    request. Mirrors the TOML-side check in `src/config.rs`.
967        if let Some(0) = self.max_concurrent_requests {
968            return Err(McpxError::Config(
969                "max_concurrent_requests must be greater than zero when set".into(),
970            ));
971        }
972
973        // 10. Auth rate-limit `max_tracked_keys` must be > 0. A zero cap
974        //     would force `BoundedKeyedLimiter` to evict on every insert
975        //     and effectively disable rate limiting.
976        if let Some(auth_cfg) = &self.auth
977            && let Some(rl) = &auth_cfg.rate_limit
978            && rl.max_tracked_keys == 0
979        {
980            return Err(McpxError::Config(
981                "auth.rate_limit.max_tracked_keys must be greater than zero".into(),
982            ));
983        }
984
985        // 11. tls_handshake_timeout must be > 0. A zero deadline would
986        //     reap every handshake before it could complete, rejecting
987        //     all TLS connections. Mirrors the TOML-side check in
988        //     `src/config.rs`.
989        if self.tls_handshake_timeout == Duration::ZERO {
990            return Err(McpxError::Config(
991                "tls_handshake_timeout must be greater than zero".into(),
992            ));
993        }
994
995        // 12. max_concurrent_tls_handshakes must be > 0. A zero-permit
996        //     semaphore would never admit a handshake, deadlocking the
997        //     TLS accept path. Mirrors the TOML-side check in
998        //     `src/config.rs`.
999        if self.max_concurrent_tls_handshakes == 0 {
1000            return Err(McpxError::Config(
1001                "max_concurrent_tls_handshakes must be greater than zero".into(),
1002            ));
1003        }
1004
1005        Ok(())
1006    }
1007}
1008
1009/// Handle for hot-reloading server configuration without restart.
1010///
1011/// Obtained via [`McpServerConfig::on_reload_ready`].
1012/// All swap operations are lock-free and wait-free -- in-flight requests
1013/// finish with the old values while new requests see the update immediately.
1014#[allow(
1015    missing_debug_implementations,
1016    reason = "contains Arc<AuthState> with non-Debug fields"
1017)]
1018pub struct ReloadHandle {
1019    auth: Option<Arc<AuthState>>,
1020    rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
1021    crl_set: Option<Arc<CrlSet>>,
1022}
1023
1024impl ReloadHandle {
1025    /// Atomically replace the API key list used by the auth middleware.
1026    pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
1027        if let Some(ref auth) = self.auth {
1028            auth.reload_keys(keys);
1029        }
1030    }
1031
1032    /// Atomically replace the RBAC policy used by the RBAC middleware.
1033    pub fn reload_rbac(&self, policy: RbacPolicy) {
1034        if let Some(ref rbac) = self.rbac {
1035            rbac.store(Arc::new(policy));
1036            tracing::info!("RBAC policy reloaded");
1037        }
1038    }
1039
1040    /// Force an immediate refresh of all cached mTLS CRLs.
1041    ///
1042    /// # Errors
1043    ///
1044    /// Returns an error if CRL refresh is unavailable or verifier rebuild fails.
1045    pub async fn refresh_crls(&self) -> Result<(), McpxError> {
1046        let Some(ref crl_set) = self.crl_set else {
1047            return Err(McpxError::Config(
1048                "CRL refresh requested but mTLS CRL support is not configured".into(),
1049            ));
1050        };
1051
1052        crl_set.force_refresh().await
1053    }
1054}
1055
1056/// Generic MCP HTTP server.
1057///
1058/// Wraps an axum server with `/healthz` and `/mcp` endpoints.
1059/// When `tls_cert_path` and `tls_key_path` are both set, the server binds
1060/// with TLS (rustls). Optionally supports mTLS client certificate auth.
1061///
1062/// # Errors
1063///
1064/// Returns an error if the TCP listener cannot bind, TLS config is invalid,
1065/// or the server fails.
1066// NOTE: cognitive complexity reduced from 111/25 to 83/25 by
1067// extracting `run_server` (serve-loop tail) and `install_oauth_proxy_routes`.
1068// Remaining flow is a linear router builder: middleware layering, feature-
1069// gated auth/RBAC wiring, and PRM/metrics installation. Further extraction
1070// would require threading many `&mut Router` helpers and hurt readability
1071// of the layer order (which is security-relevant and must stay visible).
1072#[allow(
1073    clippy::too_many_lines,
1074    clippy::cognitive_complexity,
1075    reason = "middleware layer order is security-critical and must remain visible at one glance; extracting `&mut Router` helpers would obscure the auth/RBAC/origin/rate-limit ordering"
1076)]
1077/// Internal bundle of values produced by [`build_app_router`] and
1078/// consumed by [`serve`] / [`serve_with_listener`] when driving the
1079/// HTTP listener.
1080struct AppRunParams {
1081    /// TLS cert/key paths when TLS is configured.
1082    tls_paths: Option<(PathBuf, PathBuf)>,
1083    /// Per-handshake deadline on the TLS accept path.
1084    tls_handshake_timeout: Duration,
1085    /// Cap on concurrently in-flight TLS handshakes.
1086    max_concurrent_tls_handshakes: usize,
1087    /// mTLS configuration when mutual-TLS auth is enabled.
1088    mtls_config: Option<MtlsConfig>,
1089    /// Graceful shutdown drain window.
1090    shutdown_timeout: Duration,
1091    /// Shared auth state used by hot-reload callbacks.
1092    auth_state: Option<Arc<AuthState>>,
1093    /// Hot-reloadable RBAC state used by reload callbacks.
1094    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
1095    /// Optional callback that receives the final [`ReloadHandle`].
1096    on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
1097    /// Server-internal cancellation token. Cancelled by [`run_server`]
1098    /// once the shutdown trigger fires (so rmcp's child token also
1099    /// fires, terminating in-flight MCP sessions).
1100    ct: CancellationToken,
1101    /// `"http"` or `"https"` -- used only for boot-time logging.
1102    scheme: &'static str,
1103    /// Server name -- used only for boot-time logging.
1104    name: String,
1105}
1106
1107/// Build the full application axum [`axum::Router`] (MCP route +
1108/// middleware stack + admin + OAuth + health endpoints + security
1109/// headers + CORS + compression + concurrency limit + origin check)
1110/// and the [`AppRunParams`] needed to drive it.
1111///
1112/// This is the shared core of [`serve`] and [`serve_with_listener`].
1113/// It performs *no* network I/O: callers are responsible for binding
1114/// (or accepting a pre-bound) [`TcpListener`] and invoking
1115/// [`run_server`].
1116#[allow(
1117    clippy::cognitive_complexity,
1118    reason = "router assembly is intrinsically sequential; splitting harms readability"
1119)]
1120#[allow(
1121    deprecated,
1122    reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
1123)]
1124fn build_app_router<H, F>(
1125    mut config: McpServerConfig,
1126    handler_factory: F,
1127) -> anyhow::Result<(axum::Router, AppRunParams)>
1128where
1129    H: ServerHandler + 'static,
1130    F: Fn() -> H + Send + Sync + Clone + 'static,
1131{
1132    let ct = CancellationToken::new();
1133
1134    let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
1135    tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");
1136
1137    let mcp_service = StreamableHttpService::new(
1138        move || Ok(handler_factory()),
1139        {
1140            let mut mgr = LocalSessionManager::default();
1141            mgr.session_config.keep_alive = Some(config.session_idle_timeout);
1142            mgr.into()
1143        },
1144        StreamableHttpServerConfig::default()
1145            .with_allowed_hosts(allowed_hosts)
1146            .with_sse_keep_alive(Some(config.sse_keep_alive))
1147            .with_cancellation_token(ct.child_token()),
1148    );
1149
1150    // Build the MCP route, optionally wrapped with auth and RBAC middleware.
1151    let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);
1152
1153    // Build auth state eagerly when auth is configured so we can wire both
1154    // the auth middleware *and* the optional admin router against the same
1155    // state. The middleware itself is installed further down in layer order.
1156    let auth_state: Option<Arc<AuthState>> = match config.auth {
1157        Some(ref auth_config) if auth_config.enabled => {
1158            let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
1159            let pre_auth_limiter = auth_config
1160                .rate_limit
1161                .as_ref()
1162                .map(crate::auth::build_pre_auth_limiter);
1163
1164            #[cfg(feature = "oauth")]
1165            let jwks_cache = auth_config
1166                .oauth
1167                .as_ref()
1168                .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
1169                .transpose()
1170                .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;
1171
1172            Some(Arc::new(AuthState {
1173                api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
1174                rate_limiter,
1175                pre_auth_limiter,
1176                #[cfg(feature = "oauth")]
1177                jwks_cache,
1178                seen_identities: crate::auth::SeenIdentitySet::new(),
1179                counters: crate::auth::AuthCounters::default(),
1180            }))
1181        }
1182        _ => None,
1183    };
1184
1185    // Build the RBAC policy swap early so the admin router and the later
1186    // RBAC middleware layer share the same hot-reloadable state.
1187    let rbac_swap = Arc::new(ArcSwap::new(
1188        config
1189            .rbac
1190            .clone()
1191            .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
1192    ));
1193
1194    // Optional /admin/* diagnostic routes. Merged BEFORE the
1195    // body-limit/timeout/RBAC/origin/auth layers so all of them apply.
1196    if config.admin_enabled {
1197        let Some(ref auth_state_ref) = auth_state else {
1198            return Err(anyhow::anyhow!(
1199                "admin_enabled=true requires auth to be configured and enabled"
1200            ));
1201        };
1202        let admin_state = crate::admin::AdminState {
1203            started_at: std::time::Instant::now(),
1204            name: config.name.clone(),
1205            version: config.version.clone(),
1206            auth: Some(Arc::clone(auth_state_ref)),
1207            rbac: Arc::clone(&rbac_swap),
1208        };
1209        let admin_cfg = crate::admin::AdminConfig {
1210            role: config.admin_role.clone(),
1211        };
1212        mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
1213        tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
1214    }
1215
1216    // ----- Middleware order (CRITICAL: read carefully) ------------------
1217    //
1218    // axum/tower applies layers **bottom-up** at runtime: the LAST layer
1219    // added is the OUTERMOST (runs first on a request). To achieve a
1220    // request-time flow of:
1221    //
1222    //   body-limit -> timeout -> auth -> rbac -> handler
1223    //
1224    // we add layers in the REVERSE order:
1225    //
1226    //   1. RBAC               (innermost, runs last before handler)
1227    //   2. auth               (parses identity, sets extension for RBAC)
1228    //   3. timeout            (bounds total request time)
1229    //   4. body-limit         (outermost on /mcp; caps payload before
1230    //                          anything else reads/buffers it)
1231    //
1232    // Origin validation is installed on the OUTER router (after the
1233    // /mcp router is merged in), so it also protects /healthz, /readyz,
1234    // /version, and any OAuth proxy endpoints.
1235    //
1236    // Rationale:
1237    // - Body-limit must be outermost on /mcp so RBAC (which reads the
1238    //   JSON-RPC body) cannot be DoS'd by a 100MB payload.
1239    // - Auth must run before RBAC because RBAC consumes
1240    //   `req.extensions().get::<AuthIdentity>()` to enforce per-role
1241    //   policy.
1242    // - Origin runs before auth so we reject cross-origin requests
1243    //   without spending Argon2 cycles on unauthenticated callers.
1244
1245    // [1] RBAC + tool rate-limit layer (innermost; closest to handler).
1246    // Always installed: even when RBAC is disabled, tool rate limiting may
1247    // be active (MCP spec: servers MUST rate limit tool invocations).
1248    {
1249        let tool_limiter: Option<Arc<ToolRateLimiter>> =
1250            config.tool_rate_limit.map(build_tool_rate_limiter);
1251
1252        if rbac_swap.load().is_enabled() {
1253            tracing::info!("RBAC enforcement enabled on /mcp");
1254        }
1255        if let Some(limit) = config.tool_rate_limit {
1256            tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
1257        }
1258
1259        let rbac_for_mw = Arc::clone(&rbac_swap);
1260        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1261            let p = rbac_for_mw.load_full();
1262            let tl = tool_limiter.clone();
1263            rbac_middleware(p, tl, req, next)
1264        }));
1265    }
1266
1267    // [2] Auth layer (runs before RBAC so AuthIdentity is in extensions).
1268    if let Some(ref auth_config) = config.auth
1269        && auth_config.enabled
1270    {
1271        let Some(ref state) = auth_state else {
1272            return Err(anyhow::anyhow!("auth state missing despite enabled config"));
1273        };
1274
1275        let methods: Vec<&str> = [
1276            auth_config.mtls.is_some().then_some("mTLS"),
1277            (!auth_config.api_keys.is_empty()).then_some("bearer"),
1278            #[cfg(feature = "oauth")]
1279            auth_config.oauth.is_some().then_some("oauth-jwt"),
1280        ]
1281        .into_iter()
1282        .flatten()
1283        .collect();
1284
1285        tracing::info!(
1286            methods = %methods.join(", "),
1287            api_keys = auth_config.api_keys.len(),
1288            "auth enabled on /mcp"
1289        );
1290
1291        let state_for_mw = Arc::clone(state);
1292        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
1293            let s = Arc::clone(&state_for_mw);
1294            auth_middleware(s, req, next)
1295        }));
1296    }
1297
1298    // [3] Request timeout (returns 408 on expiry). Bounds total request
1299    // duration including auth + handler.
1300    mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
1301        axum::http::StatusCode::REQUEST_TIMEOUT,
1302        config.request_timeout,
1303    ));
1304
1305    // [4] Request body size limit (OUTERMOST on /mcp). Prevents OOM /
1306    // DoS from oversized payloads BEFORE any inner layer (auth, RBAC)
1307    // attempts to buffer or parse the body.
1308    mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
1309        config.max_request_body,
1310    ));
1311
1312    // Compute the effective allowed-origins list for the outer
1313    // origin-check layer (installed on the merged router below). When
1314    // `allowed_origins` is empty but `public_url` is set, auto-derive
1315    // the origin from the public URL so MCP clients (e.g. Claude Code)
1316    // that send `Origin: <server-url>` are accepted without explicit
1317    // config.
1318    let mut effective_origins = config.allowed_origins.clone();
1319    if effective_origins.is_empty()
1320        && let Some(ref url) = config.public_url
1321    {
1322        // Origin = scheme + "://" + host (+ ":" + port if non-default).
1323        // Strip any path/query from the public URL. Offsets come from
1324        // `find`, so they are char-boundary-aligned; `get(..)` keeps that
1325        // machine-checked (a violation degrades to an empty slice).
1326        if let Some(scheme_end) = url.find("://") {
1327            let scheme_with_sep = url.get(..scheme_end + 3).unwrap_or_default();
1328            let after_scheme = url.get(scheme_end + 3..).unwrap_or_default();
1329            let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
1330            let host = after_scheme.get(..host_end).unwrap_or_default();
1331            let origin = format!("{scheme_with_sep}{host}");
1332            tracing::info!(
1333                %origin,
1334                "auto-derived allowed origin from public_url"
1335            );
1336            effective_origins.push(origin);
1337        }
1338    }
1339    let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
1340    let cors_origins = Arc::clone(&allowed_origins);
1341    let log_request_headers = config.log_request_headers;
1342
1343    let readyz_route = if let Some(check) = config.readiness_check.take() {
1344        axum::routing::get(move || readyz(Arc::clone(&check)))
1345    } else {
1346        axum::routing::get(healthz)
1347    };
1348
1349    #[allow(unused_mut)] // mut needed when oauth feature adds PRM route
1350    let mut router = axum::Router::new()
1351        .route("/healthz", axum::routing::get(healthz))
1352        .route("/readyz", readyz_route)
1353        .route(
1354            "/version",
1355            axum::routing::get({
1356                // Pre-serialize the version payload once at router-build
1357                // time. The handler then serves a cheap `Arc::clone` of the
1358                // immutable bytes per request, avoiding `serde_json::Value`
1359                // allocation + serialization on every `/version` hit.
1360                let payload_bytes: Arc<[u8]> =
1361                    serialize_version_payload(&config.name, &config.version);
1362                move || {
1363                    let p = Arc::clone(&payload_bytes);
1364                    async move {
1365                        (
1366                            [(axum::http::header::CONTENT_TYPE, "application/json")],
1367                            p.to_vec(),
1368                        )
1369                    }
1370                }
1371            }),
1372        )
1373        .merge(mcp_router);
1374
1375    // Merge application-specific routes (bypass MCP auth/RBAC middleware).
1376    // When configured, wrap them — and only them — in the per-IP rate
1377    // limiter BEFORE merging: axum layers wrap only the routes already
1378    // present on the sub-router, so the limiter can never leak onto
1379    // `/mcp`, health, admin, or OAuth endpoints, while top-level layers
1380    // (origin check, peer-address normalization, ...) still run first.
1381    if let Some(extra) = config.extra_router.take() {
1382        let extra = match config.extra_route_rate_limit {
1383            Some(per_minute) => {
1384                let limiter = build_extra_route_rate_limiter(per_minute);
1385                tracing::info!(per_minute, "extra-route per-IP rate limit enabled");
1386                extra.layer(axum::middleware::from_fn(move |req, next| {
1387                    let l = Arc::clone(&limiter);
1388                    extra_route_rate_limit_middleware(l, req, next)
1389                }))
1390            }
1391            None => extra,
1392        };
1393        router = router.merge(extra);
1394    }
1395
1396    // RFC 9728: Protected Resource Metadata endpoint.
1397    // When OAuth is configured, serve full metadata with authorization_servers.
1398    // Otherwise, serve a minimal document with just the resource URL and no
1399    // authorization_servers -- this tells MCP clients (e.g. Claude Code SDK)
1400    // that the server exists but does NOT require OAuth authentication,
1401    // preventing them from gating the connection behind a broken auth flow.
1402    let server_url = if let Some(ref url) = config.public_url {
1403        url.trim_end_matches('/').to_owned()
1404    } else {
1405        let prm_scheme = if config.tls_cert_path.is_some() {
1406            "https"
1407        } else {
1408            "http"
1409        };
1410        format!("{prm_scheme}://{}", config.bind_addr)
1411    };
1412    let resource_url = format!("{server_url}/mcp");
1413
1414    #[cfg(feature = "oauth")]
1415    let prm_metadata = if let Some(ref auth_config) = config.auth
1416        && let Some(ref oauth_config) = auth_config.oauth
1417    {
1418        crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
1419    } else {
1420        serde_json::json!({ "resource": resource_url })
1421    };
1422    #[cfg(not(feature = "oauth"))]
1423    let prm_metadata = serde_json::json!({ "resource": resource_url });
1424
1425    router = router.route(
1426        "/.well-known/oauth-protected-resource",
1427        axum::routing::get(move || {
1428            let m = prm_metadata.clone();
1429            async move { axum::Json(m) }
1430        }),
1431    );
1432
1433    // OAuth 2.1 proxy endpoints: when an OAuth proxy is configured, expose
1434    // /authorize, /token, /register, and authorization server metadata so
1435    // MCP clients can perform Authorization Code + PKCE against the upstream
1436    // IdP (e.g. Keycloak) transparently.
1437    #[cfg(feature = "oauth")]
1438    if let Some(ref auth_config) = config.auth
1439        && let Some(ref oauth_config) = auth_config.oauth
1440        && oauth_config.proxy.is_some()
1441    {
1442        router =
1443            install_oauth_proxy_routes(router, &server_url, oauth_config, auth_state.as_ref())?;
1444    }
1445
1446    // OWASP security response headers (applied to all responses).
1447    // HSTS is conditional on TLS being configured.
1448    let is_tls = config.tls_cert_path.is_some();
1449    let security_headers_cfg = Arc::new(config.security_headers.clone());
1450    router = router.layer(axum::middleware::from_fn(move |req, next| {
1451        let cfg = Arc::clone(&security_headers_cfg);
1452        security_headers_middleware(is_tls, cfg, req, next)
1453    }));
1454
1455    // CORS preflight layer (required for browser-based MCP clients).
1456    // Uses the same effective origins as the origin check middleware
1457    // (including auto-derived origin from public_url).
1458    if !cors_origins.is_empty() {
1459        let cors = tower_http::cors::CorsLayer::new()
1460            .allow_origin(
1461                cors_origins
1462                    .iter()
1463                    .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
1464                    .collect::<Vec<_>>(),
1465            )
1466            .allow_methods([
1467                axum::http::Method::GET,
1468                axum::http::Method::POST,
1469                axum::http::Method::OPTIONS,
1470            ])
1471            .allow_headers([
1472                axum::http::header::CONTENT_TYPE,
1473                axum::http::header::AUTHORIZATION,
1474            ]);
1475        router = router.layer(cors);
1476    }
1477
1478    // Optional response compression (gzip + brotli). Skips small bodies
1479    // to avoid overhead. Applied after CORS so preflight responses remain
1480    // uncompressed.
1481    if config.compression_enabled {
1482        use tower_http::compression::Predicate as _;
1483        let predicate = tower_http::compression::DefaultPredicate::new().and(
1484            tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
1485        );
1486        router = router.layer(
1487            tower_http::compression::CompressionLayer::new()
1488                .gzip(true)
1489                .br(true)
1490                .compress_when(predicate),
1491        );
1492        tracing::info!(
1493            min_size = config.compression_min_size,
1494            "response compression enabled (gzip, br)"
1495        );
1496    }
1497
1498    // Optional global concurrency cap. `load_shed` converts the
1499    // `ConcurrencyLimit` back-pressure error into 503 instead of hanging.
1500    if let Some(max) = config.max_concurrent_requests {
1501        let overload_handler = tower::ServiceBuilder::new()
1502            .layer(axum::error_handling::HandleErrorLayer::new(
1503                |_err: tower::BoxError| async {
1504                    (
1505                        axum::http::StatusCode::SERVICE_UNAVAILABLE,
1506                        axum::Json(serde_json::json!({
1507                            "error": "overloaded",
1508                            "error_description": "server is at capacity, retry later"
1509                        })),
1510                    )
1511                },
1512            ))
1513            .layer(tower::load_shed::LoadShedLayer::new())
1514            .layer(tower::limit::ConcurrencyLimitLayer::new(max));
1515        router = router.layer(overload_handler);
1516        tracing::info!(max, "global concurrency limit enabled");
1517    }
1518
1519    // JSON fallback for unmatched routes. Without this, axum returns
1520    // an empty-body 404 that breaks MCP clients (e.g. Claude Code SDK)
1521    // when they probe OAuth endpoints like /authorize or /token.
1522    router = router.fallback(|| async {
1523        (
1524            axum::http::StatusCode::NOT_FOUND,
1525            axum::Json(serde_json::json!({
1526                "error": "not_found",
1527                "error_description": "The requested endpoint does not exist"
1528            })),
1529        )
1530    });
1531
1532    // Prometheus metrics: recording middleware + separate listener.
1533    #[cfg(feature = "metrics")]
1534    if config.metrics_enabled {
1535        let metrics = Arc::new(
1536            crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
1537        );
1538        let m = Arc::clone(&metrics);
1539        router = router.layer(axum::middleware::from_fn(
1540            move |req: Request<Body>, next: Next| {
1541                let m = Arc::clone(&m);
1542                metrics_middleware(m, req, next)
1543            },
1544        ));
1545        let metrics_bind = config.metrics_bind.clone();
1546        let metrics_shutdown = ct.clone();
1547        tokio::spawn(async move {
1548            if let Err(e) =
1549                crate::metrics::serve_metrics(metrics_bind, metrics, metrics_shutdown).await
1550            {
1551                tracing::error!("metrics listener failed: {e}");
1552            }
1553        });
1554    }
1555
1556    // Peer-address normalization. Mirrors the TLS branch's peer address
1557    // into `ConnectInfo<SocketAddr>` and exposes the framework-owned
1558    // `PeerAddr` extension on both listener branches, so ALL routes on
1559    // the merged router (`/mcp`, `/healthz`, OAuth proxy endpoints,
1560    // admin endpoints, extra_router, ...) and all inner middleware see a
1561    // uniform peer-address contract regardless of TLS. Installed just
1562    // inside the origin check, which stays outermost by design.
1563    router = router.layer(axum::middleware::from_fn(normalize_peer_addr_middleware));
1564
1565    // Origin validation layer (MCP spec: servers MUST validate the
1566    // Origin header to prevent DNS rebinding attacks). Installed as the
1567    // OUTERMOST layer on the OUTER router so it protects ALL routes
1568    // (`/mcp`, `/healthz`, `/readyz`, `/version`, OAuth proxy endpoints,
1569    // admin endpoints, extra_router, etc.) and runs BEFORE auth so we
1570    // reject cross-origin attackers without spending Argon2 cycles.
1571    //
1572    // Origin-less requests (e.g. server-to-server probes, curl, native
1573    // MCP clients) are permitted; only requests with an Origin header
1574    // that does not match `effective_origins` are rejected.
1575    router = router.layer(axum::middleware::from_fn(move |req, next| {
1576        let origins = Arc::clone(&allowed_origins);
1577        origin_check_middleware(origins, log_request_headers, req, next)
1578    }));
1579
1580    let scheme = if config.tls_cert_path.is_some() {
1581        "https"
1582    } else {
1583        "http"
1584    };
1585
1586    let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
1587        (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
1588        _ => None,
1589    };
1590    let tls_handshake_timeout = config.tls_handshake_timeout;
1591    let max_concurrent_tls_handshakes = config.max_concurrent_tls_handshakes;
1592    let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();
1593
1594    Ok((
1595        router,
1596        AppRunParams {
1597            tls_paths,
1598            tls_handshake_timeout,
1599            max_concurrent_tls_handshakes,
1600            mtls_config,
1601            shutdown_timeout: config.shutdown_timeout,
1602            auth_state,
1603            rbac_swap,
1604            on_reload_ready: config.on_reload_ready.take(),
1605            ct,
1606            scheme,
1607            name: config.name.clone(),
1608        },
1609    ))
1610}
1611
1612/// Run the MCP HTTP server, binding to `config.bind_addr` and serving
1613/// until an OS shutdown signal (Ctrl-C / SIGTERM) is received.
1614///
1615/// This is the standard entry point for production deployments. For
1616/// deterministic shutdown control (e.g. integration tests), see
1617/// [`serve_with_listener`].
1618///
1619/// The configuration must be validated first via
1620/// [`McpServerConfig::validate`], which returns a [`Validated`] proof
1621/// token. This typestate guarantees, at compile time, that the server
1622/// never starts with an invalid configuration.
1623///
1624/// # Errors
1625///
1626/// Returns [`McpxError::Startup`] if binding to `config.bind_addr`
1627/// fails, or if the underlying axum server returns an error.
1628pub async fn serve<H, F>(
1629    config: Validated<McpServerConfig>,
1630    handler_factory: F,
1631) -> Result<(), McpxError>
1632where
1633    H: ServerHandler + 'static,
1634    F: Fn() -> H + Send + Sync + Clone + 'static,
1635{
1636    let config = config.into_inner();
1637    #[allow(
1638        deprecated,
1639        reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1640    )]
1641    let bind_addr = config.bind_addr.clone();
1642    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1643
1644    let listener = TcpListener::bind(&bind_addr)
1645        .await
1646        .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1647    log_listening(&params.name, params.scheme, &bind_addr);
1648
1649    run_server(
1650        router,
1651        listener,
1652        params.tls_paths,
1653        params.tls_handshake_timeout,
1654        params.max_concurrent_tls_handshakes,
1655        params.mtls_config,
1656        params.shutdown_timeout,
1657        params.auth_state,
1658        params.rbac_swap,
1659        params.on_reload_ready,
1660        params.ct,
1661    )
1662    .await
1663    .map_err(anyhow_to_startup)
1664}
1665
1666/// Run the MCP HTTP server on a pre-bound [`TcpListener`], with optional
1667/// readiness signalling and external shutdown control.
1668///
1669/// This variant is intended for **deterministic integration tests** and
1670/// for embedders that need to bind the listening socket themselves
1671/// (e.g. systemd socket activation). Compared to [`serve`]:
1672///
1673/// * The caller passes a `TcpListener` that is already bound. This
1674///   eliminates the bind race in tests that previously required
1675///   poll-the-`/healthz`-loop start-up detection.
1676/// * `ready_tx`, when `Some`, receives the socket's
1677///   [`SocketAddr`] *after* the router is built and immediately before
1678///   the server starts accepting connections. Tests can `await` the
1679///   matching `oneshot::Receiver` to know exactly when it is safe to
1680///   issue requests.
1681/// * `shutdown`, when `Some`, gives the caller a
1682///   [`CancellationToken`] that triggers the same graceful-shutdown
1683///   path as a real OS signal. This avoids cross-platform issues with
1684///   sending real `SIGTERM` from tests on Windows.
1685///
1686/// All three optional parameters degrade gracefully: if `ready_tx` is
1687/// `None`, no signal is sent; if `shutdown` is `None`, the server only
1688/// stops on an OS signal (just like [`serve`]).
1689///
1690/// # Errors
1691///
1692/// Returns [`McpxError::Startup`] if router construction fails, if reading
1693/// the listener's `local_addr()` fails, or if the underlying axum
1694/// server returns an error.
1695pub async fn serve_with_listener<H, F>(
1696    listener: TcpListener,
1697    config: Validated<McpServerConfig>,
1698    handler_factory: F,
1699    ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
1700    shutdown: Option<CancellationToken>,
1701) -> Result<(), McpxError>
1702where
1703    H: ServerHandler + 'static,
1704    F: Fn() -> H + Send + Sync + Clone + 'static,
1705{
1706    let config = config.into_inner();
1707    let local_addr = listener
1708        .local_addr()
1709        .map_err(|e| io_to_startup("listener.local_addr", e))?;
1710    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1711
1712    log_listening(&params.name, params.scheme, &local_addr.to_string());
1713
1714    // Forward external shutdown into the server-internal cancellation
1715    // token so `run_server`'s shutdown trigger picks it up alongside
1716    // any real OS signal.
1717    if let Some(external) = shutdown {
1718        let internal = params.ct.clone();
1719        tokio::spawn(async move {
1720            external.cancelled().await;
1721            internal.cancel();
1722        });
1723    }
1724
1725    // Signal readiness *after* the router is fully built and external
1726    // shutdown is wired, but *before* run_server takes ownership of
1727    // the listener. The receiver can immediately issue requests.
1728    if let Some(tx) = ready_tx {
1729        // Receiver may have been dropped (test gave up). That's fine.
1730        let _ = tx.send(local_addr);
1731    }
1732
1733    run_server(
1734        router,
1735        listener,
1736        params.tls_paths,
1737        params.tls_handshake_timeout,
1738        params.max_concurrent_tls_handshakes,
1739        params.mtls_config,
1740        params.shutdown_timeout,
1741        params.auth_state,
1742        params.rbac_swap,
1743        params.on_reload_ready,
1744        params.ct,
1745    )
1746    .await
1747    .map_err(anyhow_to_startup)
1748}
1749
1750/// Emit the standard "listening on …" log lines used by both
1751/// [`serve`] and [`serve_with_listener`].
1752#[allow(
1753    clippy::cognitive_complexity,
1754    reason = "tracing::info! macro expansions inflate the score; logic is trivial"
1755)]
1756fn log_listening(name: &str, scheme: &str, addr: &str) {
1757    tracing::info!("{name} listening on {addr}");
1758    tracing::info!("  MCP endpoint: {scheme}://{addr}/mcp");
1759    tracing::info!("  Health check: {scheme}://{addr}/healthz");
1760    tracing::info!("  Readiness:   {scheme}://{addr}/readyz");
1761}
1762
1763/// Drive the chosen axum server variant (TLS or plain) with a graceful
1764/// shutdown window. Consumes the router and listener.
1765///
1766/// # Shutdown semantics
1767///
1768/// A single shutdown trigger (the FIRST of: OS signal via
1769/// `shutdown_signal()`, or external cancellation of `ct`) starts BOTH:
1770///
1771/// 1. axum's `.with_graceful_shutdown(...)` future, which stops
1772///    accepting new connections and waits for in-flight requests to
1773///    drain;
1774/// 2. a `tokio::time::sleep(shutdown_timeout)` race that forces exit if
1775///    drainage exceeds `shutdown_timeout`.
1776///
1777/// Previously this function awaited `shutdown_signal()` independently
1778/// in BOTH branches of a `tokio::select!`. Because `shutdown_signal`
1779/// resolves once per future and consumes one signal, the force-exit
1780/// timer was tied to a SECOND signal (a second SIGTERM the operator
1781/// would never send). Under a single SIGTERM the graceful drain could
1782/// hang indefinitely. The current implementation derives both branches
1783/// from a single shared trigger so the timeout race is anchored to the
1784/// FIRST (and only) signal.
1785#[allow(
1786    clippy::too_many_arguments,
1787    clippy::cognitive_complexity,
1788    reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
1789)]
1790async fn run_server(
1791    router: axum::Router,
1792    listener: TcpListener,
1793    tls_paths: Option<(PathBuf, PathBuf)>,
1794    tls_handshake_timeout: Duration,
1795    max_concurrent_tls_handshakes: usize,
1796    mtls_config: Option<MtlsConfig>,
1797    shutdown_timeout: Duration,
1798    auth_state: Option<Arc<AuthState>>,
1799    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
1800    mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
1801    ct: CancellationToken,
1802) -> anyhow::Result<()> {
1803    // `shutdown_trigger` fires when the FIRST source resolves: either
1804    // an OS signal (Ctrl-C / SIGTERM) or external cancellation of `ct`
1805    // (which the test harness uses for deterministic shutdown).
1806    let shutdown_trigger = CancellationToken::new();
1807    {
1808        let trigger = shutdown_trigger.clone();
1809        let parent = ct.clone();
1810        tokio::spawn(async move {
1811            tokio::select! {
1812                () = shutdown_signal() => {}
1813                () = parent.cancelled() => {}
1814            }
1815            trigger.cancel();
1816        });
1817    }
1818
1819    let graceful = {
1820        let trigger = shutdown_trigger.clone();
1821        let ct = ct.clone();
1822        async move {
1823            trigger.cancelled().await;
1824            tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
1825            ct.cancel();
1826        }
1827    };
1828
1829    let force_exit_timer = {
1830        let trigger = shutdown_trigger.clone();
1831        async move {
1832            trigger.cancelled().await;
1833            tokio::time::sleep(shutdown_timeout).await;
1834        }
1835    };
1836
1837    if let Some((cert_path, key_path)) = tls_paths {
1838        let crl_set = if let Some(mtls) = mtls_config.as_ref()
1839            && mtls.crl_enabled
1840        {
1841            let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
1842            let (crl_set, discover_rx) =
1843                mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
1844                    .await
1845                    .map_err(|error| anyhow::anyhow!(error.to_string()))?;
1846            tokio::spawn(mtls_revocation::run_crl_refresher(
1847                Arc::clone(&crl_set),
1848                discover_rx,
1849                ct.clone(),
1850            ));
1851            Some(crl_set)
1852        } else {
1853            None
1854        };
1855
1856        if let Some(cb) = on_reload_ready.take() {
1857            cb(ReloadHandle {
1858                auth: auth_state.clone(),
1859                rbac: Some(Arc::clone(&rbac_swap)),
1860                crl_set: crl_set.clone(),
1861            });
1862        }
1863
1864        let tls_listener = TlsListener::new(
1865            listener,
1866            &cert_path,
1867            &key_path,
1868            mtls_config.as_ref(),
1869            crl_set,
1870            tls_handshake_timeout,
1871            max_concurrent_tls_handshakes,
1872        )?;
1873        let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
1874        tokio::select! {
1875            result = axum::serve(tls_listener, make_svc)
1876                .with_graceful_shutdown(graceful) => { result?; }
1877            () = force_exit_timer => {
1878                tracing::warn!("shutdown timeout exceeded, forcing exit");
1879            }
1880        }
1881    } else {
1882        if let Some(cb) = on_reload_ready.take() {
1883            cb(ReloadHandle {
1884                auth: auth_state,
1885                rbac: Some(rbac_swap),
1886                crl_set: None,
1887            });
1888        }
1889
1890        let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
1891        tokio::select! {
1892            result = axum::serve(listener, make_svc)
1893                .with_graceful_shutdown(graceful) => { result?; }
1894            () = force_exit_timer => {
1895                tracing::warn!("shutdown timeout exceeded, forcing exit");
1896            }
1897        }
1898    }
1899
1900    Ok(())
1901}
1902
/// Install the OAuth 2.1 proxy endpoints on `router`.
///
/// Mounts `/.well-known/oauth-authorization-server`, `/authorize`,
/// `/token`, and `/register`. It also mounts the `/introspect` /
/// `/revoke` admin routes when `proxy.expose_admin_endpoints` is set and
/// the matching `introspection_url` / `revocation_url` is configured. If
/// `oauth_config.proxy` is `None`, `router` is returned unchanged.
///
/// # Errors
///
/// Propagates the error from
/// [`crate::oauth::OauthHttpClient::with_config`] if the shared HTTP
/// client cannot be initialized. Returns [`McpxError::Startup`] if admin
/// routes are mounted with `require_auth_on_admin_endpoints` set but no
/// `auth_state` was supplied.
#[cfg(feature = "oauth")]
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let Some(ref proxy) = oauth_config.proxy else {
        return Ok(router);
    };

    // Single shared HTTP client for all proxy endpoints. Cloning is
    // cheap (refcounted) and shares the underlying connection pool.
    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;

    let asm = crate::oauth::authorization_server_metadata(server_url, oauth_config);
    let router = router.route(
        "/.well-known/oauth-authorization-server",
        axum::routing::get(move || {
            let m = asm.clone();
            async move { axum::Json(m) }
        }),
    );

    let proxy_authorize = proxy.clone();
    let router = router.route(
        "/authorize",
        axum::routing::get(
            move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
                let p = proxy_authorize.clone();
                async move { crate::oauth::handle_authorize(&p, &query.unwrap_or_default()) }
            },
        ),
    );

    let proxy_token = proxy.clone();
    let token_http = http.clone();
    let router = router.route(
        "/token",
        axum::routing::post(move |body: String| {
            let p = proxy_token.clone();
            let h = token_http.clone();
            async move { crate::oauth::handle_token(&h, &p, &body).await }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    let proxy_register = proxy.clone();
    let router = router.route(
        "/register",
        axum::routing::post(move |axum::Json(body): axum::Json<serde_json::Value>| {
            // Moving out of the capture is fine: axum clones the handler
            // closure for each request.
            let p = proxy_register;
            async move { axum::Json(crate::oauth::handle_register(&p, &body)) }
        })
        .layer(axum::middleware::from_fn(
            oauth_token_cache_headers_middleware,
        )),
    );

    let admin_routes_enabled = proxy.expose_admin_endpoints
        && (proxy.introspection_url.is_some() || proxy.revocation_url.is_some());
    // M3 escape-hatch in effect: validate() let this through because the
    // operator explicitly opted in. Surface it loudly at startup so the
    // choice is auditable in logs. Only warn when admin routes are actually
    // mounted. With neither an introspection nor a revocation URL, nothing
    // is exposed and the warning would be misleading.
    if admin_routes_enabled
        && !proxy.require_auth_on_admin_endpoints
        && proxy.allow_unauthenticated_admin_endpoints
    {
        tracing::warn!(
            "OAuth introspect/revoke endpoints are unauthenticated by explicit \
             allow_unauthenticated_admin_endpoints opt-out; ensure an \
             authenticated reverse proxy fronts these routes"
        );
    }

    let admin_router = if admin_routes_enabled {
        build_oauth_admin_router(proxy, http, auth_state)?
    } else {
        axum::Router::new()
    };

    let router = router.merge(admin_router);

    tracing::info!(
        introspect = proxy.expose_admin_endpoints && proxy.introspection_url.is_some(),
        revoke = proxy.expose_admin_endpoints && proxy.revocation_url.is_some(),
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
2003
/// Build the optional `/introspect` + `/revoke` admin sub-router.
///
/// `/introspect` is mounted only when `proxy.introspection_url` is set, and
/// `/revoke` only when `proxy.revocation_url` is set. Both share the given
/// `http` client. Every route is wrapped in
/// [`oauth_token_cache_headers_middleware`] so RFC 6749 §5.1 / RFC 6750
/// §5.4 cache headers are emitted. When
/// `proxy.require_auth_on_admin_endpoints` is set, the routes are also
/// wrapped in the auth middleware as the outermost layer.
///
/// # Errors
///
/// Returns [`McpxError::Startup`] if `proxy.require_auth_on_admin_endpoints`
/// is set but `auth_state` is `None`.
#[cfg(feature = "oauth")]
fn build_oauth_admin_router(
    proxy: &crate::oauth::OAuthProxyConfig,
    http: crate::oauth::OauthHttpClient,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    // Settle the authentication requirement up front: a missing auth state
    // is a configuration error no matter which routes end up mounted.
    let guard_state = match (proxy.require_auth_on_admin_endpoints, auth_state) {
        (false, _) => None,
        (true, Some(state)) => Some(Arc::clone(state)),
        (true, None) => {
            return Err(McpxError::Startup(
                "oauth proxy admin endpoints require auth state".into(),
            ));
        }
    };

    let mut routes = axum::Router::new();

    if proxy.introspection_url.is_some() {
        let introspect_cfg = proxy.clone();
        let introspect_client = http.clone();
        routes = routes.route(
            "/introspect",
            axum::routing::post(move |body: String| {
                let cfg = introspect_cfg.clone();
                let client = introspect_client.clone();
                async move { crate::oauth::handle_introspect(&client, &cfg, &body).await }
            }),
        );
    }

    if proxy.revocation_url.is_some() {
        let revoke_cfg = proxy.clone();
        let revoke_client = http;
        routes = routes.route(
            "/revoke",
            axum::routing::post(move |body: String| {
                let cfg = revoke_cfg.clone();
                let client = revoke_client.clone();
                async move { crate::oauth::handle_revoke(&client, &cfg, &body).await }
            }),
        );
    }

    let routes = routes.layer(axum::middleware::from_fn(
        oauth_token_cache_headers_middleware,
    ));

    Ok(match guard_state {
        // Added last, so it is the outermost layer and authenticates
        // before the cache-header middleware or any handler runs.
        Some(state) => routes.layer(axum::middleware::from_fn(move |req, next| {
            let s = Arc::clone(&state);
            auth_middleware(s, req, next)
        })),
        None => routes,
    })
}
2062
2063/// Build the host allow-list for rmcp's DNS rebinding protection.
2064///
2065/// Includes loopback hosts by default, then augments with host/authority
2066/// derived from `public_url` and the server bind address.
2067fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
2068    let mut hosts = vec![
2069        "localhost".to_owned(),
2070        "127.0.0.1".to_owned(),
2071        "::1".to_owned(),
2072    ];
2073
2074    if let Some(url) = public_url
2075        && let Ok(uri) = url.parse::<axum::http::Uri>()
2076        && let Some(authority) = uri.authority()
2077    {
2078        let host = authority.host().to_owned();
2079        if !hosts.iter().any(|h| h == &host) {
2080            hosts.push(host);
2081        }
2082
2083        let authority = authority.as_str().to_owned();
2084        if !hosts.iter().any(|h| h == &authority) {
2085            hosts.push(authority);
2086        }
2087    }
2088
2089    if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
2090        && let Some(authority) = uri.authority()
2091    {
2092        let host = authority.host().to_owned();
2093        if !hosts.iter().any(|h| h == &host) {
2094            hosts.push(host);
2095        }
2096
2097        let authority = authority.as_str().to_owned();
2098        if !hosts.iter().any(|h| h == &authority) {
2099            hosts.push(authority);
2100        }
2101    }
2102
2103    hosts
2104}
2105
2106// - TLS support -
2107
2108/// Implement axum's `Connected` trait for `TlsConnInfo` so that
2109/// `ConnectInfo<TlsConnInfo>` is available in middleware when serving
2110/// over our custom `TlsListener`.
2111///
2112/// The identity is read directly from the wrapping
2113/// [`AuthenticatedTlsStream`], which guarantees one-to-one correspondence
2114/// between the TLS connection and its mTLS identity. This eliminates the
2115/// previous shared-map approach which was vulnerable to ephemeral-port
2116/// reuse races (an unauthenticated reconnection from the same `(IP, port)`
2117/// pair could alias a stale entry).
2118impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
2119    for TlsConnInfo
2120{
2121    fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
2122        let addr = *target.remote_addr();
2123        let identity = target.io().identity().cloned();
2124        TlsConnInfo::new(addr, identity)
2125    }
2126}
2127
/// Default deadline for a single TLS handshake on the accept path.
///
/// Keeps idle or slow-loris clients from holding a handshake worker task
/// (and the semaphore permit it owns) indefinitely.
///
/// Configurable since 1.9.0 via
/// [`McpServerConfig::with_tls_handshake_timeout`].
const DEFAULT_TLS_HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(10_000);

/// Default cap on TLS handshakes in flight at the same time.
///
/// Once the cap is reached, the acceptor task stops pulling connections off
/// the kernel backlog. This applies backpressure instead of accepting
/// sockets in user space only to drop them.
///
/// Configurable since 1.9.0 via
/// [`McpServerConfig::with_max_concurrent_tls_handshakes`].
const DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES: usize = 256;

/// Depth of the queue that hands completed handshakes from the acceptor
/// task to `axum::serve`'s `accept()` loop.
///
/// Handshake workers block on `send` while the queue is full. A slow
/// accept loop therefore throttles new handshakes instead of letting
/// finished connections pile up without bound.
const TLS_ACCEPT_CHANNEL_CAPACITY: usize = 32;
2150
/// A TLS-wrapping listener that implements axum's `Listener` trait.
///
/// TCP accepts and TLS handshakes run on a dedicated background task: each
/// accepted connection's handshake is spawned onto its own worker task,
/// bounded by a configurable concurrent-handshake cap (default
/// [`DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES`]) and a per-handshake timeout
/// (default [`DEFAULT_TLS_HANDSHAKE_TIMEOUT`]). A slow or idle client
/// therefore cannot stall other connections behind a serialized inline
/// handshake.
///
/// When mTLS is configured, client certificates are verified against the
/// configured CA and the client identity is extracted at handshake time.
/// The extracted identity is bound to the connection itself via the
/// returned [`AuthenticatedTlsStream`], so it is impossible for an
/// unrelated connection to observe it.
///
/// NOTE(review): field declaration order is also the drop order; keep
/// `rx` ahead of `acceptor_task` unless the `Drop` behaviour is revisited.
struct TlsListener {
    /// Bound address, captured eagerly before the `TcpListener` moves into
    /// the acceptor task (afterwards the listener is no longer reachable
    /// from here to query it).
    local_addr: SocketAddr,
    /// Completed handshakes produced by the acceptor task's workers, each
    /// paired with its `SocketAddr`.
    rx: mpsc::Receiver<(AuthenticatedTlsStream, SocketAddr)>,
    /// Background task driving TCP accepts and concurrent TLS handshakes.
    /// Aborted on drop so the listener releases its port deterministically.
    acceptor_task: tokio::task::JoinHandle<()>,
}
2176
2177impl TlsListener {
2178    fn new(
2179        inner: TcpListener,
2180        cert_path: &Path,
2181        key_path: &Path,
2182        mtls_config: Option<&MtlsConfig>,
2183        crl_set: Option<Arc<CrlSet>>,
2184        handshake_timeout: Duration,
2185        max_concurrent_handshakes: usize,
2186    ) -> anyhow::Result<Self> {
2187        // Install the ring crypto provider (ok to call multiple times).
2188        rustls::crypto::ring::default_provider()
2189            .install_default()
2190            .ok();
2191
2192        let certs = load_certs(cert_path)?;
2193        let key = load_key(key_path)?;
2194
2195        let mtls_default_role;
2196
2197        let tls_config = if let Some(mtls) = mtls_config {
2198            mtls_default_role = mtls.default_role.clone();
2199            let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
2200            {
2201                let Some(crl_set) = crl_set else {
2202                    return Err(anyhow::anyhow!(
2203                        "mTLS CRL verifier requested but CRL state was not initialized"
2204                    ));
2205                };
2206                Arc::new(DynamicClientCertVerifier::new(crl_set))
2207            } else {
2208                let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
2209                if mtls.required {
2210                    rustls::server::WebPkiClientVerifier::builder(root_store)
2211                        .build()
2212                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
2213                } else {
2214                    rustls::server::WebPkiClientVerifier::builder(root_store)
2215                        .allow_unauthenticated()
2216                        .build()
2217                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
2218                }
2219            };
2220
2221            tracing::info!(
2222                ca = %mtls.ca_cert_path.display(),
2223                required = mtls.required,
2224                crl_enabled = mtls.crl_enabled,
2225                "mTLS client auth configured"
2226            );
2227
2228            rustls::ServerConfig::builder_with_protocol_versions(&[
2229                &rustls::version::TLS12,
2230                &rustls::version::TLS13,
2231            ])
2232            .with_client_cert_verifier(verifier)
2233            .with_single_cert(certs, key)?
2234        } else {
2235            mtls_default_role = "viewer".to_owned();
2236            rustls::ServerConfig::builder_with_protocol_versions(&[
2237                &rustls::version::TLS12,
2238                &rustls::version::TLS13,
2239            ])
2240            .with_no_client_auth()
2241            .with_single_cert(certs, key)?
2242        };
2243
2244        let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
2245        tracing::info!(
2246            "TLS enabled (cert: {}, key: {})",
2247            cert_path.display(),
2248            key_path.display()
2249        );
2250        let local_addr = inner.local_addr()?;
2251        let (tx, rx) = mpsc::channel(TLS_ACCEPT_CHANNEL_CAPACITY);
2252        let acceptor_task = tokio::spawn(run_tls_acceptor(
2253            inner,
2254            acceptor,
2255            mtls_default_role,
2256            tx,
2257            handshake_timeout,
2258            max_concurrent_handshakes,
2259        ));
2260        Ok(Self {
2261            local_addr,
2262            rx,
2263            acceptor_task,
2264        })
2265    }
2266
2267    /// Extract the mTLS client cert identity from a completed TLS handshake.
2268    /// Returns `None` if no client certificate was presented or if the
2269    /// certificate could not be parsed into an [`AuthIdentity`].
2270    fn extract_handshake_identity(
2271        tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
2272        default_role: &str,
2273        addr: SocketAddr,
2274    ) -> Option<AuthIdentity> {
2275        let (_, server_conn) = tls_stream.get_ref();
2276        let cert_der = server_conn.peer_certificates()?.first()?;
2277        let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
2278        tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
2279        Some(id)
2280    }
2281}
2282
2283/// Drive TCP accepts and concurrent TLS handshakes for [`TlsListener`].
2284///
2285/// Each accepted connection's handshake runs on its own worker task under
2286/// a permit from a `max_concurrent_handshakes`-sized semaphore and a
2287/// `handshake_timeout` deadline. Completed handshakes are pushed to `tx`;
2288/// failures and timeouts are logged at DEBUG and the connection dropped.
2289/// The loop exits when the owning [`TlsListener`] is dropped.
2290async fn run_tls_acceptor(
2291    listener: TcpListener,
2292    acceptor: tokio_rustls::TlsAcceptor,
2293    default_role: String,
2294    tx: mpsc::Sender<(AuthenticatedTlsStream, SocketAddr)>,
2295    handshake_timeout: Duration,
2296    max_concurrent_handshakes: usize,
2297) {
2298    let inflight = Arc::new(Semaphore::new(max_concurrent_handshakes));
2299    loop {
2300        // Acquire the permit BEFORE accepting: at saturation, pending
2301        // connections wait in the kernel backlog instead of being accepted
2302        // and then buffered or dropped in user space.
2303        let Ok(permit) = Arc::clone(&inflight).acquire_owned().await else {
2304            // The semaphore is never closed; defensive exit.
2305            return;
2306        };
2307        let (stream, addr) = match listener.accept().await {
2308            Ok(pair) => pair,
2309            Err(e) => {
2310                tracing::debug!("TCP accept error: {e}");
2311                continue;
2312            }
2313        };
2314        if tx.is_closed() {
2315            // The listener was dropped (shutdown): stop accepting.
2316            return;
2317        }
2318        let acceptor = acceptor.clone();
2319        let default_role = default_role.clone();
2320        let tx = tx.clone();
2321        tokio::spawn(async move {
2322            let _permit = permit;
2323            match tokio::time::timeout(handshake_timeout, acceptor.accept(stream)).await {
2324                Ok(Ok(tls_stream)) => {
2325                    let identity =
2326                        TlsListener::extract_handshake_identity(&tls_stream, &default_role, addr);
2327                    let wrapped = AuthenticatedTlsStream {
2328                        inner: tls_stream,
2329                        identity,
2330                    };
2331                    // The receiver only disappears during shutdown; discard
2332                    // the completed connection quietly rather than logging.
2333                    let _ = tx.send((wrapped, addr)).await;
2334                }
2335                Ok(Err(e)) => {
2336                    tracing::debug!("TLS handshake failed from {addr}: {e}");
2337                }
2338                Err(_elapsed) => {
2339                    tracing::debug!(
2340                        "TLS handshake timed out from {addr} after {handshake_timeout:?}"
2341                    );
2342                }
2343            }
2344        });
2345    }
2346}
2347
/// A TLS stream paired with the mTLS identity extracted at handshake time.
///
/// Wraps [`tokio_rustls::server::TlsStream`] so the verified client
/// identity travels with the connection itself. This replaces the previous
/// shared `MtlsIdentities` map. That removes the
/// `(SocketAddr) -> AuthIdentity` aliasing risk caused by ephemeral-port
/// reuse, and with it the need for an LRU eviction policy.
///
/// The wrapper is `Unpin` because its inner stream is `Unpin` (as
/// [`tokio::net::TcpStream`] is `Unpin`). `AsyncRead`/`AsyncWrite`
/// delegation can therefore re-pin the inner stream safely with `Pin::new`.
pub(crate) struct AuthenticatedTlsStream {
    /// Server-side TLS stream; every `AsyncRead`/`AsyncWrite` call is
    /// forwarded to it.
    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
    /// Identity extracted from the client certificate during the handshake.
    /// `None` when no certificate was presented or it could not be mapped
    /// to an [`AuthIdentity`].
    identity: Option<AuthIdentity>,
}
2363
impl AuthenticatedTlsStream {
    /// Borrow the mTLS client identity captured during the TLS handshake.
    ///
    /// Returns `None` when the peer presented no client certificate, or
    /// presented one that could not be mapped to an [`AuthIdentity`].
    #[must_use]
    pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
        self.identity.as_ref()
    }
}
2371
2372impl std::fmt::Debug for AuthenticatedTlsStream {
2373    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2374        f.debug_struct("AuthenticatedTlsStream")
2375            .field("identity", &self.identity.as_ref().map(|id| &id.name))
2376            .finish_non_exhaustive()
2377    }
2378}
2379
2380impl tokio::io::AsyncRead for AuthenticatedTlsStream {
2381    fn poll_read(
2382        mut self: Pin<&mut Self>,
2383        cx: &mut std::task::Context<'_>,
2384        buf: &mut tokio::io::ReadBuf<'_>,
2385    ) -> std::task::Poll<std::io::Result<()>> {
2386        Pin::new(&mut self.inner).poll_read(cx, buf)
2387    }
2388}
2389
2390impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
2391    fn poll_write(
2392        mut self: Pin<&mut Self>,
2393        cx: &mut std::task::Context<'_>,
2394        buf: &[u8],
2395    ) -> std::task::Poll<std::io::Result<usize>> {
2396        Pin::new(&mut self.inner).poll_write(cx, buf)
2397    }
2398
2399    fn poll_flush(
2400        mut self: Pin<&mut Self>,
2401        cx: &mut std::task::Context<'_>,
2402    ) -> std::task::Poll<std::io::Result<()>> {
2403        Pin::new(&mut self.inner).poll_flush(cx)
2404    }
2405
2406    fn poll_shutdown(
2407        mut self: Pin<&mut Self>,
2408        cx: &mut std::task::Context<'_>,
2409    ) -> std::task::Poll<std::io::Result<()>> {
2410        Pin::new(&mut self.inner).poll_shutdown(cx)
2411    }
2412
2413    fn poll_write_vectored(
2414        mut self: Pin<&mut Self>,
2415        cx: &mut std::task::Context<'_>,
2416        bufs: &[std::io::IoSlice<'_>],
2417    ) -> std::task::Poll<std::io::Result<usize>> {
2418        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
2419    }
2420
2421    fn is_write_vectored(&self) -> bool {
2422        self.inner.is_write_vectored()
2423    }
2424}
2425
2426impl axum::serve::Listener for TlsListener {
2427    type Io = AuthenticatedTlsStream;
2428    type Addr = SocketAddr;
2429
2430    /// Yield the next fully-handshaken TLS connection.
2431    ///
2432    /// Cancel-safe: this is a plain `mpsc::Receiver::recv`, so cancelling
2433    /// the future (axum selects it against graceful shutdown) never loses
2434    /// a connection.
2435    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
2436        if let Some(pair) = self.rx.recv().await {
2437            return pair;
2438        }
2439        // The channel only closes if the acceptor task terminated, which
2440        // means the TcpListener is gone and the OS already refuses new
2441        // connections. `Listener::accept` is infallible and panicking is
2442        // forbidden, so park forever: existing connections keep being
2443        // served and graceful shutdown still completes.
2444        tracing::error!("TLS acceptor task terminated; no further connections will be accepted");
2445        std::future::pending().await
2446    }
2447
2448    fn local_addr(&self) -> std::io::Result<Self::Addr> {
2449        Ok(self.local_addr)
2450    }
2451}
2452
impl Drop for TlsListener {
    /// Stop the background acceptor task.
    fn drop(&mut self) {
        // Request cancellation of the acceptor task so no further TCP
        // connections are accepted. The `TcpListener` the task owns, and
        // therefore the bound port, is released once the runtime cancels
        // the task.
        //
        // Handshake workers that were already spawned are not tracked here.
        // They run until their handshake finishes or times out. At that
        // point `tx.send` fails because `rx` is gone, and the connection is
        // discarded quietly.
        self.acceptor_task.abort();
    }
}
2460
2461fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
2462    use rustls::pki_types::pem::PemObject;
2463    let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
2464        .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
2465        .collect::<Result<_, _>>()
2466        .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
2467    anyhow::ensure!(
2468        !certs.is_empty(),
2469        "no certificates found in {}",
2470        path.display()
2471    );
2472    Ok(certs)
2473}
2474
2475fn load_client_auth_roots(
2476    path: &Path,
2477) -> anyhow::Result<(
2478    Vec<rustls::pki_types::CertificateDer<'static>>,
2479    Arc<RootCertStore>,
2480)> {
2481    let ca_certs = load_certs(path)?;
2482    let mut root_store = RootCertStore::empty();
2483    for cert in &ca_certs {
2484        root_store
2485            .add(cert.clone())
2486            .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
2487    }
2488
2489    Ok((ca_certs, Arc::new(root_store)))
2490}
2491
2492fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
2493    use rustls::pki_types::pem::PemObject;
2494    rustls::pki_types::PrivateKeyDer::from_pem_file(path)
2495        .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
2496}
2497
2498#[allow(
2499    clippy::unused_async,
2500    reason = "axum route handler signature requires `async fn` even when the body is synchronous"
2501)]
2502async fn healthz() -> impl IntoResponse {
2503    axum::Json(serde_json::json!({
2504        "status": "ok",
2505    }))
2506}
2507
2508/// Build the `/version` JSON payload for a given server name and version.
2509///
2510/// Build metadata (`build_git_sha`, `build_timestamp`, `rust_version`) is
2511/// read at compile time from the `RMCP_SERVER_KIT_BUILD_SHA`,
2512/// `RMCP_SERVER_KIT_BUILD_TIME`, and `RMCP_SERVER_KIT_RUSTC_VERSION` env
2513/// vars. Unset values resolve to `"unknown"`.
2514fn version_payload(name: &str, version: &str) -> serde_json::Value {
2515    serde_json::json!({
2516        "name": name,
2517        "version": version,
2518        "build_git_sha": option_env!("RMCP_SERVER_KIT_BUILD_SHA").unwrap_or("unknown"),
2519        "build_timestamp": option_env!("RMCP_SERVER_KIT_BUILD_TIME").unwrap_or("unknown"),
2520        "rust_version": option_env!("RMCP_SERVER_KIT_RUSTC_VERSION").unwrap_or("unknown"),
2521        "mcpx_version": env!("CARGO_PKG_VERSION"),
2522    })
2523}
2524
2525/// Pre-serialize the `/version` payload to immutable bytes.
2526///
2527/// This is called once at router-build time so per-request handling can
2528/// reuse a cheap `Arc<[u8]>` clone instead of re-serializing a
2529/// [`serde_json::Value`] on every hit.
2530///
2531/// Serialization of a flat `serde_json::Value` of static-string fields
2532/// cannot fail in practice; the fallback to `b"{}"` exists only to
2533/// satisfy the crate-wide `unwrap_used` / `expect_used` lint policy.
2534fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
2535    let value = version_payload(name, version);
2536    serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
2537}
2538
2539async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2540    let status = check().await;
2541    let ready = status
2542        .get("ready")
2543        .and_then(serde_json::Value::as_bool)
2544        .unwrap_or(false);
2545    let code = if ready {
2546        axum::http::StatusCode::OK
2547    } else {
2548        axum::http::StatusCode::SERVICE_UNAVAILABLE
2549    };
2550    (code, axum::Json(status))
2551}
2552
2553/// Wait for SIGINT (ctrl-c) or SIGTERM (container stop).
2554///
2555/// On non-Unix platforms, only SIGINT is handled.
2556async fn shutdown_signal() {
2557    let ctrl_c = tokio::signal::ctrl_c();
2558
2559    #[cfg(unix)]
2560    {
2561        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
2562            Ok(mut term) => {
2563                tokio::select! {
2564                    _ = ctrl_c => {}
2565                    _ = term.recv() => {}
2566                }
2567            }
2568            Err(e) => {
2569                tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
2570                ctrl_c.await.ok();
2571            }
2572        }
2573    }
2574
2575    #[cfg(not(unix))]
2576    {
2577        ctrl_c.await.ok();
2578    }
2579}
2580
2581// -- Origin validation (MCP 2025-11-25 spec, section 2.0.1) --
2582
/// Record HTTP request metrics (method, path, status, duration).
///
/// Label values are restricted to a bounded set. Otherwise clients could
/// grow the Prometheus registry without limit by sending arbitrary request
/// lines (one new time series per distinct path or method token).
///
/// - `path` is the route template axum stores in the
///   [`MatchedPath`](axum::extract::MatchedPath) extension (e.g. `/healthz`).
///   Requests for which routing recorded no matched route (404 fallbacks
///   and similar) are labelled `"unmatched"`.
/// - `method` is the standard HTTP method name. Extension methods are
///   labelled `"OTHER"`.
///
/// `status` is a three-digit code and is already bounded.
#[cfg(feature = "metrics")]
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    /// Path label for requests that matched no registered route.
    const UNMATCHED_PATH_LABEL: &str = "unmatched";
    /// Method label for non-standard (extension) HTTP methods.
    const OTHER_METHOD_LABEL: &str = "OTHER";

    let method = match req.method().as_str() {
        standard @ ("GET" | "HEAD" | "POST" | "PUT" | "DELETE" | "CONNECT" | "OPTIONS"
        | "TRACE" | "PATCH") => standard.to_owned(),
        _ => OTHER_METHOD_LABEL.to_owned(),
    };
    // Use the route template, never the raw (attacker-controlled) URI path.
    let path = req
        .extensions()
        .get::<axum::extract::MatchedPath>()
        .map_or(UNMATCHED_PATH_LABEL, axum::extract::MatchedPath::as_str)
        .to_owned();
    let start = std::time::Instant::now();

    let response = next.run(req).await;

    let status = response.status().as_u16().to_string();
    let duration = start.elapsed().as_secs_f64();

    metrics
        .http_requests_total
        .with_label_values(&[&method, &path, &status])
        .inc();
    metrics
        .http_request_duration_seconds
        .with_label_values(&[&method, &path])
        .observe(duration);

    response
}
2612
2613/// OWASP security header hardening applied to every response.
2614///
2615/// Sets: `X-Content-Type-Options`, `X-Frame-Options`, `Cache-Control`,
2616/// `Referrer-Policy`, `Cross-Origin-Opener-Policy`, `Cross-Origin-Resource-Policy`,
2617/// `Cross-Origin-Embedder-Policy`, `Permissions-Policy`,
2618/// `X-Permitted-Cross-Domain-Policies`, `Content-Security-Policy`,
2619/// `X-DNS-Prefetch-Control`, and (when TLS is active) `Strict-Transport-Security`.
2620///
2621/// Each header's value can be customised via [`SecurityHeadersConfig`]
2622/// on [`McpServerConfig`]. See that type for the three-state semantic
2623/// (`None` = default, `Some("")` = omit, `Some(v)` = override).
2624async fn security_headers_middleware(
2625    is_tls: bool,
2626    cfg: Arc<SecurityHeadersConfig>,
2627    req: Request<Body>,
2628    next: Next,
2629) -> axum::response::Response {
2630    use axum::http::{HeaderName, header};
2631
2632    let mut resp = next.run(req).await;
2633    let headers = resp.headers_mut();
2634
2635    // Strip server identity headers to reduce information leakage.
2636    headers.remove(header::SERVER);
2637    headers.remove(HeaderName::from_static("x-powered-by"));
2638
2639    apply_security_header(
2640        headers,
2641        header::X_CONTENT_TYPE_OPTIONS,
2642        cfg.x_content_type_options.as_deref(),
2643        "nosniff",
2644    );
2645    apply_security_header(
2646        headers,
2647        header::X_FRAME_OPTIONS,
2648        cfg.x_frame_options.as_deref(),
2649        "deny",
2650    );
2651    apply_security_header(
2652        headers,
2653        header::CACHE_CONTROL,
2654        cfg.cache_control.as_deref(),
2655        "no-store, max-age=0",
2656    );
2657    apply_security_header(
2658        headers,
2659        header::REFERRER_POLICY,
2660        cfg.referrer_policy.as_deref(),
2661        "no-referrer",
2662    );
2663    apply_security_header(
2664        headers,
2665        HeaderName::from_static("cross-origin-opener-policy"),
2666        cfg.cross_origin_opener_policy.as_deref(),
2667        "same-origin",
2668    );
2669    apply_security_header(
2670        headers,
2671        HeaderName::from_static("cross-origin-resource-policy"),
2672        cfg.cross_origin_resource_policy.as_deref(),
2673        "same-origin",
2674    );
2675    apply_security_header(
2676        headers,
2677        HeaderName::from_static("cross-origin-embedder-policy"),
2678        cfg.cross_origin_embedder_policy.as_deref(),
2679        "require-corp",
2680    );
2681    apply_security_header(
2682        headers,
2683        HeaderName::from_static("permissions-policy"),
2684        cfg.permissions_policy.as_deref(),
2685        "accelerometer=(), camera=(), geolocation=(), microphone=()",
2686    );
2687    apply_security_header(
2688        headers,
2689        HeaderName::from_static("x-permitted-cross-domain-policies"),
2690        cfg.x_permitted_cross_domain_policies.as_deref(),
2691        "none",
2692    );
2693    apply_security_header(
2694        headers,
2695        HeaderName::from_static("content-security-policy"),
2696        cfg.content_security_policy.as_deref(),
2697        "default-src 'none'; frame-ancestors 'none'",
2698    );
2699    apply_security_header(
2700        headers,
2701        HeaderName::from_static("x-dns-prefetch-control"),
2702        cfg.x_dns_prefetch_control.as_deref(),
2703        "off",
2704    );
2705
2706    if is_tls {
2707        apply_security_header(
2708            headers,
2709            header::STRICT_TRANSPORT_SECURITY,
2710            cfg.strict_transport_security.as_deref(),
2711            "max-age=63072000; includeSubDomains",
2712        );
2713    }
2714
2715    resp
2716}
2717
2718/// Set a single security header on the response, honouring the
2719/// three-state override semantic (None = default, Some("") = omit,
2720/// Some(value) = override).
2721///
2722/// Defence-in-depth: if an override value somehow reaches this point
2723/// despite [`validate_security_headers`] having approved it (e.g. a
2724/// runtime mutation on a non-`Validated` field), we log at error level
2725/// and fall back to the static default rather than panicking. The
2726/// `Validated<McpServerConfig>` type makes that path unreachable in
2727/// well-typed code paths.
2728fn apply_security_header(
2729    headers: &mut axum::http::HeaderMap,
2730    name: axum::http::HeaderName,
2731    override_value: Option<&str>,
2732    default: &'static str,
2733) {
2734    use axum::http::HeaderValue;
2735
2736    match override_value {
2737        None => {
2738            headers.insert(name, HeaderValue::from_static(default));
2739        }
2740        Some("") => {
2741            // Operator explicitly opted out of this header.
2742        }
2743        Some(v) => match HeaderValue::from_str(v) {
2744            Ok(hv) => {
2745                headers.insert(name, hv);
2746            }
2747            Err(err) => {
2748                tracing::error!(
2749                    header = %name,
2750                    error = %err,
2751                    "invalid security header override reached middleware; using default"
2752                );
2753                headers.insert(name, HeaderValue::from_static(default));
2754            }
2755        },
2756    }
2757}
2758
2759/// Validate every non-empty entry in a [`SecurityHeadersConfig`].
2760///
2761/// - `None` and `Some("")` are accepted unconditionally (use-default and
2762///   omit, respectively).
2763/// - `Some(v)` is rejected if `axum::http::HeaderValue::from_str(v)` fails.
2764/// - `strict_transport_security` additionally rejects any value
2765///   containing `preload` (case-insensitive). Operators who genuinely
2766///   want to commit to the HSTS preload list must do so via a future
2767///   explicit `with_hsts_preload(true)` builder, not by smuggling
2768///   `preload` through this knob.
2769fn validate_security_headers(cfg: &SecurityHeadersConfig) -> Result<(), McpxError> {
2770    use axum::http::HeaderValue;
2771
2772    let fields: &[(&str, Option<&str>)] = &[
2773        (
2774            "x_content_type_options",
2775            cfg.x_content_type_options.as_deref(),
2776        ),
2777        ("x_frame_options", cfg.x_frame_options.as_deref()),
2778        ("cache_control", cfg.cache_control.as_deref()),
2779        ("referrer_policy", cfg.referrer_policy.as_deref()),
2780        (
2781            "cross_origin_opener_policy",
2782            cfg.cross_origin_opener_policy.as_deref(),
2783        ),
2784        (
2785            "cross_origin_resource_policy",
2786            cfg.cross_origin_resource_policy.as_deref(),
2787        ),
2788        (
2789            "cross_origin_embedder_policy",
2790            cfg.cross_origin_embedder_policy.as_deref(),
2791        ),
2792        ("permissions_policy", cfg.permissions_policy.as_deref()),
2793        (
2794            "x_permitted_cross_domain_policies",
2795            cfg.x_permitted_cross_domain_policies.as_deref(),
2796        ),
2797        (
2798            "content_security_policy",
2799            cfg.content_security_policy.as_deref(),
2800        ),
2801        (
2802            "x_dns_prefetch_control",
2803            cfg.x_dns_prefetch_control.as_deref(),
2804        ),
2805        (
2806            "strict_transport_security",
2807            cfg.strict_transport_security.as_deref(),
2808        ),
2809    ];
2810
2811    for (field, value) in fields {
2812        let Some(v) = value else { continue };
2813        if v.is_empty() {
2814            continue;
2815        }
2816        if let Err(err) = HeaderValue::from_str(v) {
2817            return Err(McpxError::Config(format!(
2818                "invalid security_headers.{field}: {err}"
2819            )));
2820        }
2821    }
2822
2823    if let Some(v) = cfg.strict_transport_security.as_deref()
2824        && !v.is_empty()
2825        && v.to_ascii_lowercase().contains("preload")
2826    {
2827        return Err(McpxError::Config(format!(
2828            "invalid security_headers.strict_transport_security: {v:?} contains the `preload` directive; \
2829             HSTS preload must be opted into explicitly via a dedicated builder, not via this knob"
2830        )));
2831    }
2832
2833    Ok(())
2834}
2835
/// Add the cache-related headers that RFC 6749 §5.1 and RFC 6750 §5.4
/// require on OAuth token-issuing responses.
///
/// `Cache-Control: no-store, max-age=0` is already set globally by
/// [`security_headers_middleware`]. This layer adds:
///
/// - `Pragma: no-cache`, mandated by RFC 6749 §5.1 for HTTP/1.0 caches;
/// - `Vary: Authorization`, mandated by RFC 6750 §5.4 for responses that
///   depend on the `Authorization` header.
///
/// It is meant for the OAuth proxy's token-class endpoints (`/token`,
/// `/register`, `/introspect`, `/revoke`). `Vary` is appended rather than
/// inserted. That preserves any existing `Vary` value, such as
/// `Accept-Encoding` from a compression layer or `Origin` from CORS.
#[cfg(feature = "oauth")]
async fn oauth_token_cache_headers_middleware(
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    use axum::http::{HeaderValue, header};

    let mut response = next.run(req).await;
    let pragma = HeaderValue::from_static("no-cache");
    let vary = HeaderValue::from_static("Authorization");
    response.headers_mut().insert(header::PRAGMA, pragma);
    response.headers_mut().append(header::VARY, vary);
    response
}
2863
2864/// Normalize peer-address request extensions across listener branches.
2865///
2866/// The make-service installs `ConnectInfo<SocketAddr>` on the plain
2867/// listener but `ConnectInfo<TlsConnInfo>` on the TLS listener (the
2868/// latter additionally carries the connection-bound mTLS identity and
2869/// stays `pub(crate)` — see the anti-aliasing rationale on
2870/// [`TlsConnInfo`]). Application routes — in particular those merged via
2871/// [`McpServerConfig::with_extra_router`], which bypass the auth
2872/// middleware and its private fallback — could therefore not read the
2873/// peer address under TLS.
2874///
2875/// This middleware makes both branches look identical to every route and
2876/// inner middleware:
2877///
2878/// 1. mirrors the TLS peer address into `ConnectInfo<SocketAddr>` when
2879///    (and only when) it is absent, so stock axum-ecosystem extractors
2880///    work unmodified, and
2881/// 2. inserts the framework-owned [`PeerAddr`] extension on both
2882///    branches.
2883///
2884/// Precedence mirrors the auth middleware: an existing
2885/// `ConnectInfo<SocketAddr>` always wins and is never overwritten. The
2886/// peer address is deliberately not logged here.
2887async fn normalize_peer_addr_middleware(
2888    mut req: Request<Body>,
2889    next: Next,
2890) -> axum::response::Response {
2891    let direct = req
2892        .extensions()
2893        .get::<ConnectInfo<SocketAddr>>()
2894        .map(|ci| ci.0);
2895    let from_tls = req
2896        .extensions()
2897        .get::<ConnectInfo<TlsConnInfo>>()
2898        .map(|ci| ci.0.addr);
2899    if let Some(addr) = direct.or(from_tls) {
2900        if direct.is_none() {
2901            req.extensions_mut().insert(ConnectInfo(addr));
2902        }
2903        req.extensions_mut().insert(PeerAddr::new(addr));
2904    }
2905    next.run(req).await
2906}
2907
/// Per-IP rate limiter for `extra_router` routes, keyed by the direct
/// socket peer address (never by forwarded headers). It uses the same
/// memory-bounded machinery as the tool limiter ([`crate::rbac`]).
pub(crate) type ExtraRouteRateLimiter = BoundedKeyedLimiter<IpAddr>;

/// Cap on distinct source IPs tracked by the extra-route limiter.
///
/// It mirrors the tool limiter's bound. At saturation, memory stays bounded
/// through idle-prune plus LRU eviction. The cost is shared-fate fairness
/// under key spray: an attacker churning many IPs can push quieter
/// legitimate IPs back to fresh buckets.
const EXTRA_ROUTE_MAX_TRACKED_KEYS: usize = 10_000;

/// How long a source IP may stay idle before its bucket is evicted from
/// the extra-route limiter (15 minutes, mirroring the tool limiter).
const EXTRA_ROUTE_IDLE_EVICTION: Duration = Duration::from_mins(15);
2923
2924/// Build the per-IP limiter for `extra_router` routes.
2925///
2926/// `per_minute` is validated to be nonzero by
2927/// [`McpServerConfig::validate`]; the underlying `.max(1)` clamp in
2928/// [`BoundedKeyedLimiter::with_per_minute`] is therefore unreachable.
2929fn build_extra_route_rate_limiter(per_minute: u32) -> Arc<ExtraRouteRateLimiter> {
2930    Arc::new(BoundedKeyedLimiter::with_per_minute(
2931        per_minute,
2932        EXTRA_ROUTE_MAX_TRACKED_KEYS,
2933        EXTRA_ROUTE_IDLE_EVICTION,
2934    ))
2935}
2936
2937/// Per-IP rate limit middleware for `extra_router` routes.
2938///
2939/// Applied to the application-supplied router **before** it is merged
2940/// into the top-level router, so it wraps exactly the extra routes
2941/// (and their fallback, if any) and nothing else — `/mcp`, health,
2942/// admin, and OAuth endpoints are never affected. Outer layers (origin
2943/// check, peer-address normalization, security headers, metrics) still
2944/// wrap these routes and run first, so both `ConnectInfo` forms are
2945/// populated by the time this middleware reads them.
2946///
2947/// Semantics mirror the tool/auth limiters exactly: keyed by the
2948/// direct peer `IpAddr` (no `X-Forwarded-For`), fail-open when no peer
2949/// address is present (cannot happen under [`serve`]), and on limit a
2950/// plain-text 429 via [`McpxError::RateLimited`] — no `Retry-After`
2951/// header, consistent with every other limiter in the crate.
2952async fn extra_route_rate_limit_middleware(
2953    limiter: Arc<ExtraRouteRateLimiter>,
2954    req: Request<Body>,
2955    next: Next,
2956) -> axum::response::Response {
2957    let peer_ip: Option<IpAddr> = req
2958        .extensions()
2959        .get::<ConnectInfo<SocketAddr>>()
2960        .map(|ci| ci.0.ip())
2961        .or_else(|| {
2962            req.extensions()
2963                .get::<ConnectInfo<TlsConnInfo>>()
2964                .map(|ci| ci.0.addr.ip())
2965        });
2966    if let Some(ip) = peer_ip
2967        && limiter.check_key(&ip).is_err()
2968    {
2969        tracing::warn!(%ip, "extra route request rate limited");
2970        return McpxError::RateLimited(
2971            "too many requests to application routes from this source".into(),
2972        )
2973        .into_response();
2974    }
2975    next.run(req).await
2976}
2977
2978/// Per the MCP spec: if the Origin header is present and its value is not in
2979/// the allowed list, respond with 403 Forbidden. Requests without an Origin
2980/// header are allowed through (e.g. non-browser clients like curl, SDKs).
2981async fn origin_check_middleware(
2982    allowed: Arc<[String]>,
2983    log_request_headers: bool,
2984    req: Request<Body>,
2985    next: Next,
2986) -> axum::response::Response {
2987    let method = req.method().clone();
2988    let path = req.uri().path().to_owned();
2989
2990    log_incoming_request(&method, &path, req.headers(), log_request_headers);
2991
2992    if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
2993        let origin_str = origin.to_str().unwrap_or("");
2994        if !allowed.iter().any(|a| a == origin_str) {
2995            tracing::warn!(
2996                origin = origin_str,
2997                %method,
2998                %path,
2999                allowed = ?&*allowed,
3000                "rejected request: Origin not allowed"
3001            );
3002            return (
3003                axum::http::StatusCode::FORBIDDEN,
3004                "Forbidden: Origin not allowed",
3005            )
3006                .into_response();
3007        }
3008    }
3009    next.run(req).await
3010}
3011
3012/// Emit a DEBUG log for an incoming request, optionally including the full
3013/// (redacted) header set.
3014fn log_incoming_request(
3015    method: &axum::http::Method,
3016    path: &str,
3017    headers: &axum::http::HeaderMap,
3018    log_request_headers: bool,
3019) {
3020    if log_request_headers {
3021        tracing::debug!(
3022            %method,
3023            %path,
3024            headers = %format_request_headers_for_log(headers),
3025            "incoming request"
3026        );
3027    } else {
3028        tracing::debug!(%method, %path, "incoming request");
3029    }
3030}
3031
3032fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
3033    headers
3034        .iter()
3035        .map(|(k, v)| {
3036            let name = k.as_str();
3037            if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
3038                format!("{name}: [REDACTED]")
3039            } else {
3040                format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
3041            }
3042        })
3043        .collect::<Vec<_>>()
3044        .join(", ")
3045}
3046
3047// -- stdio transport --
3048
3049/// Serve an MCP server over stdin/stdout (stdio transport).
3050///
3051/// # Security warnings
3052///
3053/// - **No authentication**: the parent process has full, unrestricted access.
3054/// - **No RBAC**: all tools are available regardless of policy.
3055/// - **No TLS**: messages travel over OS pipes in plaintext.
3056/// - **Single client**: only the parent process can connect.
3057/// - **No Origin validation**: not applicable to stdio.
3058///
3059/// Use this only when the MCP client spawns the server as a trusted subprocess
3060/// (e.g. Claude Desktop, VS Code Copilot). For network-accessible deployments,
3061/// use `serve()` (Streamable HTTP) instead.
3062///
3063/// # Errors
3064///
3065/// Returns [`McpxError::Startup`] if the handler fails to initialize or the
3066/// transport disconnects unexpectedly.
3067// NOTE: reported complexity 32/25 is driven entirely by `tracing::*!`
3068// macro expansion in this 18-line function (info/warn/info + two matches).
3069// There is nothing meaningful to extract; the allow stays.
3070#[allow(
3071    clippy::cognitive_complexity,
3072    reason = "complexity is purely tracing macro expansion (info/warn + match arms); 18 lines of straight-line code, nothing meaningful to extract"
3073)]
3074pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
3075where
3076    H: ServerHandler + 'static,
3077{
3078    use rmcp::ServiceExt as _;
3079
3080    tracing::info!("stdio transport: serving on stdin/stdout");
3081    tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
3082
3083    let transport = rmcp::transport::io::stdio();
3084
3085    let service = handler
3086        .serve(transport)
3087        .await
3088        .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
3089
3090    if let Err(e) = service.waiting().await {
3091        tracing::warn!(error = %e, "stdio session ended with error");
3092    }
3093    tracing::info!("stdio session ended");
3094    Ok(())
3095}
3096
3097#[cfg(test)]
3098mod tests {
3099    #![allow(
3100        clippy::unwrap_used,
3101        clippy::expect_used,
3102        clippy::panic,
3103        clippy::indexing_slicing,
3104        clippy::unwrap_in_result,
3105        clippy::print_stdout,
3106        clippy::print_stderr,
3107        deprecated,
3108        reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
3109    )]
3110    use std::{sync::Arc, time::Duration};
3111
3112    use axum::{
3113        body::Body,
3114        http::{Request, StatusCode, header},
3115        response::IntoResponse,
3116    };
3117    use http_body_util::BodyExt;
3118    use tower::ServiceExt as _;
3119
3120    use super::*;
3121
3122    // -- McpServerConfig --
3123
3124    #[test]
3125    fn server_config_new_defaults() {
3126        let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
3127        assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
3128        assert_eq!(cfg.name, "test-server");
3129        assert_eq!(cfg.version, "1.0.0");
3130        assert!(cfg.tls_cert_path.is_none());
3131        assert!(cfg.tls_key_path.is_none());
3132        assert!(cfg.auth.is_none());
3133        assert!(cfg.rbac.is_none());
3134        assert!(cfg.allowed_origins.is_empty());
3135        assert!(cfg.tool_rate_limit.is_none());
3136        assert!(cfg.readiness_check.is_none());
3137        assert_eq!(cfg.max_request_body, 1024 * 1024);
3138        assert_eq!(cfg.request_timeout, Duration::from_mins(2));
3139        assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
3140        assert!(!cfg.log_request_headers);
3141        assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(10));
3142        assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
3143    }
3144
3145    #[test]
3146    fn tls_handshake_builders_set_fields() {
3147        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3148            .with_tls_handshake_timeout(Duration::from_secs(3))
3149            .with_max_concurrent_tls_handshakes(64);
3150        assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(3));
3151        assert_eq!(cfg.max_concurrent_tls_handshakes, 64);
3152    }
3153
3154    #[test]
3155    fn validate_rejects_zero_tls_handshake_timeout() {
3156        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3157            .with_tls_handshake_timeout(Duration::ZERO);
3158        let err = cfg.validate().expect_err("zero handshake timeout");
3159        assert!(err.to_string().contains("tls_handshake_timeout"));
3160    }
3161
3162    #[test]
3163    fn validate_rejects_zero_max_concurrent_tls_handshakes() {
3164        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3165            .with_max_concurrent_tls_handshakes(0);
3166        let err = cfg.validate().expect_err("zero handshake concurrency");
3167        assert!(err.to_string().contains("max_concurrent_tls_handshakes"));
3168    }
3169
3170    #[test]
3171    fn validate_consumes_and_proves() {
3172        // Valid config -> Validated wrapper, original is consumed.
3173        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
3174        let validated = cfg.validate().expect("valid config");
3175        // as_inner() gives read-only access to inner fields.
3176        assert_eq!(validated.as_inner().name, "test-server");
3177        // into_inner recovers the raw value.
3178        let raw = validated.into_inner();
3179        assert_eq!(raw.name, "test-server");
3180
3181        // Invalid config (zero max_request_body) -> Err.
3182        let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
3183        bad.max_request_body = 0;
3184        assert!(bad.validate().is_err(), "zero body cap must fail validate");
3185    }
3186
3187    #[test]
3188    fn validate_rejects_zero_max_concurrent_requests() {
3189        let cfg =
3190            McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_max_concurrent_requests(0);
3191        let err = cfg.validate().expect_err("zero concurrency cap must fail");
3192        assert!(
3193            format!("{err}").contains("max_concurrent_requests"),
3194            "error should mention max_concurrent_requests, got: {err}"
3195        );
3196    }
3197
3198    #[test]
3199    fn validate_rejects_zero_max_tracked_keys() {
3200        // Defaults mirror auth::default_max_attempts / default_idle_eviction
3201        // (module-private in auth.rs); spelled out here for review clarity.
3202        let rl = crate::auth::RateLimitConfig {
3203            max_attempts_per_minute: 30,
3204            pre_auth_max_per_minute: None,
3205            max_tracked_keys: 0,
3206            idle_eviction: Duration::from_secs(15 * 60),
3207        };
3208        let auth_cfg = AuthConfig {
3209            enabled: true,
3210            api_keys: Vec::new(),
3211            mtls: None,
3212            rate_limit: Some(rl),
3213            #[cfg(feature = "oauth")]
3214            oauth: None,
3215        };
3216        let cfg = McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_auth(auth_cfg);
3217        let err = cfg.validate().expect_err("zero max_tracked_keys must fail");
3218        assert!(
3219            format!("{err}").contains("max_tracked_keys"),
3220            "error should mention max_tracked_keys, got: {err}"
3221        );
3222    }
3223
3224    #[test]
3225    fn derive_allowed_hosts_includes_public_host() {
3226        let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
3227        assert!(
3228            hosts.iter().any(|h| h == "mcp.example.com"),
3229            "public_url host must be allowed"
3230        );
3231    }
3232
3233    #[test]
3234    fn derive_allowed_hosts_includes_bind_authority() {
3235        let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
3236        assert!(
3237            hosts.iter().any(|h| h == "127.0.0.1"),
3238            "bind host must be allowed"
3239        );
3240        assert!(
3241            hosts.iter().any(|h| h == "127.0.0.1:8080"),
3242            "bind authority must be allowed"
3243        );
3244    }
3245
3246    // -- healthz --
3247
3248    #[tokio::test]
3249    async fn healthz_returns_ok_json() {
3250        let resp = healthz().await.into_response();
3251        assert_eq!(resp.status(), StatusCode::OK);
3252        let body = resp.into_body().collect().await.unwrap().to_bytes();
3253        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
3254        assert_eq!(json["status"], "ok");
3255        assert!(
3256            json.get("name").is_none(),
3257            "healthz must not expose server name"
3258        );
3259        assert!(
3260            json.get("version").is_none(),
3261            "healthz must not expose version"
3262        );
3263    }
3264
3265    // -- readyz --
3266
3267    #[tokio::test]
3268    async fn readyz_returns_ok_when_ready() {
3269        let check: ReadinessCheck =
3270            Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
3271        let resp = readyz(check).await.into_response();
3272        assert_eq!(resp.status(), StatusCode::OK);
3273        let body = resp.into_body().collect().await.unwrap().to_bytes();
3274        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
3275        assert_eq!(json["ready"], true);
3276        assert!(
3277            json.get("name").is_none(),
3278            "readyz must not expose server name"
3279        );
3280        assert!(
3281            json.get("version").is_none(),
3282            "readyz must not expose version"
3283        );
3284        assert_eq!(json["db"], "connected");
3285    }
3286
3287    #[tokio::test]
3288    async fn readyz_returns_503_when_not_ready() {
3289        let check: ReadinessCheck =
3290            Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
3291        let resp = readyz(check).await.into_response();
3292        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
3293    }
3294
3295    #[tokio::test]
3296    async fn readyz_returns_503_when_ready_missing() {
3297        let check: ReadinessCheck =
3298            Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
3299        let resp = readyz(check).await.into_response();
3300        // Missing "ready" field defaults to false -> 503
3301        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
3302    }
3303
3304    // -- normalize_peer_addr_middleware / PeerAddr --
3305
3306    /// Build a test router that reports the request's peer-address
3307    /// extensions as `"<ConnectInfo>|<PeerAddr>"` (empty when absent).
3308    fn peer_probe_router() -> axum::Router {
3309        async fn probe(req: Request<Body>) -> String {
3310            let ci = req
3311                .extensions()
3312                .get::<ConnectInfo<SocketAddr>>()
3313                .map(|c| c.0.to_string())
3314                .unwrap_or_default();
3315            let pa = req
3316                .extensions()
3317                .get::<PeerAddr>()
3318                .map(|p| p.addr.to_string())
3319                .unwrap_or_default();
3320            format!("{ci}|{pa}")
3321        }
3322        axum::Router::new()
3323            .route("/probe", axum::routing::get(probe))
3324            .layer(axum::middleware::from_fn(normalize_peer_addr_middleware))
3325    }
3326
3327    async fn body_string(resp: axum::response::Response) -> String {
3328        let bytes = resp.into_body().collect().await.unwrap().to_bytes();
3329        String::from_utf8(bytes.to_vec()).unwrap()
3330    }
3331
3332    #[tokio::test]
3333    async fn normalize_preserves_existing_connect_info_and_mirrors_peer_addr() {
3334        // Precedence proof: when both extensions exist with DIFFERENT
3335        // addresses, ConnectInfo<SocketAddr> wins and is never overwritten.
3336        let plain: SocketAddr = "10.0.0.1:1111".parse().unwrap();
3337        let tls: SocketAddr = "10.0.0.2:2222".parse().unwrap();
3338        let req = Request::builder()
3339            .uri("/probe")
3340            .extension(ConnectInfo(plain))
3341            .extension(ConnectInfo(TlsConnInfo::new(tls, None)))
3342            .body(Body::empty())
3343            .unwrap();
3344        let resp = peer_probe_router().oneshot(req).await.unwrap();
3345        assert_eq!(resp.status(), StatusCode::OK);
3346        assert_eq!(body_string(resp).await, format!("{plain}|{plain}"));
3347    }
3348
3349    #[tokio::test]
3350    async fn normalize_inserts_connect_info_and_peer_addr_from_tls() {
3351        let tls: SocketAddr = "192.168.1.7:50443".parse().unwrap();
3352        let req = Request::builder()
3353            .uri("/probe")
3354            .extension(ConnectInfo(TlsConnInfo::new(tls, None)))
3355            .body(Body::empty())
3356            .unwrap();
3357        let resp = peer_probe_router().oneshot(req).await.unwrap();
3358        assert_eq!(resp.status(), StatusCode::OK);
3359        assert_eq!(body_string(resp).await, format!("{tls}|{tls}"));
3360    }
3361
3362    #[tokio::test]
3363    async fn normalize_no_op_without_any_connect_info() {
3364        let req = Request::builder()
3365            .uri("/probe")
3366            .body(Body::empty())
3367            .unwrap();
3368        let resp = peer_probe_router().oneshot(req).await.unwrap();
3369        assert_eq!(resp.status(), StatusCode::OK);
3370        assert_eq!(body_string(resp).await, "|");
3371    }
3372
3373    #[tokio::test]
3374    async fn peer_addr_extractor_rejects_when_absent() {
3375        async fn h(peer: PeerAddr) -> String {
3376            peer.addr.to_string()
3377        }
3378        let app = axum::Router::new().route("/p", axum::routing::get(h));
3379        let req = Request::builder().uri("/p").body(Body::empty()).unwrap();
3380        let resp = app.oneshot(req).await.unwrap();
3381        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
3382    }
3383
3384    #[tokio::test]
3385    async fn peer_addr_extractor_returns_value_when_present() {
3386        async fn h(peer: PeerAddr) -> String {
3387            peer.addr.to_string()
3388        }
3389        let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap();
3390        let app = axum::Router::new().route("/p", axum::routing::get(h));
3391        let req = Request::builder()
3392            .uri("/p")
3393            .extension(PeerAddr::new(addr))
3394            .body(Body::empty())
3395            .unwrap();
3396        let resp = app.oneshot(req).await.unwrap();
3397        assert_eq!(resp.status(), StatusCode::OK);
3398        assert_eq!(body_string(resp).await, addr.to_string());
3399    }
3400
3401    #[tokio::test]
3402    async fn peer_addr_via_extension_extractor() {
3403        async fn h(axum::Extension(peer): axum::Extension<PeerAddr>) -> String {
3404            peer.addr.to_string()
3405        }
3406        let addr: SocketAddr = "127.0.0.1:4242".parse().unwrap();
3407        let app = axum::Router::new().route("/p", axum::routing::get(h));
3408        let req = Request::builder()
3409            .uri("/p")
3410            .extension(PeerAddr::new(addr))
3411            .body(Body::empty())
3412            .unwrap();
3413        let resp = app.oneshot(req).await.unwrap();
3414        assert_eq!(resp.status(), StatusCode::OK);
3415        assert_eq!(body_string(resp).await, addr.to_string());
3416    }
3417
3418    // -- extra_route_rate_limit_middleware --
3419
3420    /// Probe router with the extra-route limiter installed, mirroring
3421    /// the layer-before-merge wiring in `build_app_router`.
3422    fn limited_router(per_minute: u32) -> axum::Router {
3423        let limiter = build_extra_route_rate_limiter(per_minute);
3424        axum::Router::new()
3425            .route("/limited", axum::routing::get(|| async { "ok" }))
3426            .layer(axum::middleware::from_fn(move |req, next| {
3427                let l = Arc::clone(&limiter);
3428                extra_route_rate_limit_middleware(l, req, next)
3429            }))
3430    }
3431
3432    fn limited_req(ip: &str) -> Request<Body> {
3433        let addr: SocketAddr = format!("{ip}:40000").parse().unwrap();
3434        Request::builder()
3435            .uri("/limited")
3436            .extension(ConnectInfo(addr))
3437            .body(Body::empty())
3438            .unwrap()
3439    }
3440
3441    #[tokio::test]
3442    async fn extra_route_limiter_denies_over_quota() {
3443        let app = limited_router(2);
3444        for i in 0..2 {
3445            let resp = app.clone().oneshot(limited_req("10.1.1.1")).await.unwrap();
3446            assert_eq!(resp.status(), StatusCode::OK, "request {i} should pass");
3447        }
3448        let resp = app.clone().oneshot(limited_req("10.1.1.1")).await.unwrap();
3449        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
3450        let body = body_string(resp).await;
3451        assert!(
3452            body.contains("too many requests to application routes"),
3453            "deny body should match the limiter message, got: {body}"
3454        );
3455    }
3456
3457    #[tokio::test]
3458    async fn extra_route_limiter_isolates_keys() {
3459        let app = limited_router(2);
3460        for _ in 0..2 {
3461            let resp = app.clone().oneshot(limited_req("10.2.2.2")).await.unwrap();
3462            assert_eq!(resp.status(), StatusCode::OK);
3463        }
3464        let exhausted = app.clone().oneshot(limited_req("10.2.2.2")).await.unwrap();
3465        assert_eq!(exhausted.status(), StatusCode::TOO_MANY_REQUESTS);
3466        // A different source IP still has a fresh bucket.
3467        let other = app.clone().oneshot(limited_req("10.3.3.3")).await.unwrap();
3468        assert_eq!(other.status(), StatusCode::OK);
3469    }
3470
3471    #[tokio::test]
3472    async fn extra_route_limiter_fails_open_without_peer() {
3473        let app = limited_router(1);
3474        for i in 0..3 {
3475            let req = Request::builder()
3476                .uri("/limited")
3477                .body(Body::empty())
3478                .unwrap();
3479            let resp = app.clone().oneshot(req).await.unwrap();
3480            assert_eq!(
3481                resp.status(),
3482                StatusCode::OK,
3483                "request {i} should fail open"
3484            );
3485        }
3486    }
3487
3488    #[tokio::test]
3489    async fn extra_route_limiter_extracts_tls_conn_info() {
3490        let app = limited_router(2);
3491        let mk = || {
3492            let addr: SocketAddr = "192.168.9.9:55555".parse().unwrap();
3493            Request::builder()
3494                .uri("/limited")
3495                .extension(ConnectInfo(TlsConnInfo::new(addr, None)))
3496                .body(Body::empty())
3497                .unwrap()
3498        };
3499        for _ in 0..2 {
3500            assert_eq!(
3501                app.clone().oneshot(mk()).await.unwrap().status(),
3502                StatusCode::OK
3503            );
3504        }
3505        let resp = app.clone().oneshot(mk()).await.unwrap();
3506        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
3507    }
3508
3509    #[test]
3510    fn validate_rejects_zero_extra_route_rate_limit() {
3511        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
3512            .with_extra_route_rate_limit(0);
3513        let err = cfg.validate().expect_err("zero extra route rate limit");
3514        assert!(err.to_string().contains("extra_route_rate_limit"));
3515    }
3516
3517    // -- origin_check_middleware --
3518
3519    /// Build a test router with origin check middleware and a simple handler.
3520    fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
3521        let allowed: Arc<[String]> = Arc::from(origins);
3522        axum::Router::new()
3523            .route("/test", axum::routing::get(|| async { "ok" }))
3524            .layer(axum::middleware::from_fn(move |req, next| {
3525                let a = Arc::clone(&allowed);
3526                origin_check_middleware(a, log_request_headers, req, next)
3527            }))
3528    }
3529
3530    #[tokio::test]
3531    async fn origin_allowed_passes() {
3532        let app = origin_router(vec!["http://localhost:3000".into()], false);
3533        let req = Request::builder()
3534            .uri("/test")
3535            .header(header::ORIGIN, "http://localhost:3000")
3536            .body(Body::empty())
3537            .unwrap();
3538        let resp = app.oneshot(req).await.unwrap();
3539        assert_eq!(resp.status(), StatusCode::OK);
3540    }
3541
3542    #[tokio::test]
3543    async fn origin_rejected_returns_403() {
3544        let app = origin_router(vec!["http://localhost:3000".into()], false);
3545        let req = Request::builder()
3546            .uri("/test")
3547            .header(header::ORIGIN, "http://evil.com")
3548            .body(Body::empty())
3549            .unwrap();
3550        let resp = app.oneshot(req).await.unwrap();
3551        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
3552    }
3553
3554    #[tokio::test]
3555    async fn no_origin_header_passes() {
3556        let app = origin_router(vec!["http://localhost:3000".into()], false);
3557        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3558        let resp = app.oneshot(req).await.unwrap();
3559        assert_eq!(resp.status(), StatusCode::OK);
3560    }
3561
3562    #[tokio::test]
3563    async fn empty_allowlist_rejects_any_origin() {
3564        let app = origin_router(vec![], false);
3565        let req = Request::builder()
3566            .uri("/test")
3567            .header(header::ORIGIN, "http://anything.com")
3568            .body(Body::empty())
3569            .unwrap();
3570        let resp = app.oneshot(req).await.unwrap();
3571        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
3572    }
3573
3574    #[tokio::test]
3575    async fn empty_allowlist_passes_without_origin() {
3576        let app = origin_router(vec![], false);
3577        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3578        let resp = app.oneshot(req).await.unwrap();
3579        assert_eq!(resp.status(), StatusCode::OK);
3580    }
3581
3582    #[test]
3583    fn format_request_headers_redacts_sensitive_values() {
3584        let mut headers = axum::http::HeaderMap::new();
3585        headers.insert("authorization", "Bearer secret-token".parse().unwrap());
3586        headers.insert("cookie", "sid=abc".parse().unwrap());
3587        headers.insert("x-request-id", "req-123".parse().unwrap());
3588
3589        let out = format_request_headers_for_log(&headers);
3590        assert!(out.contains("authorization: [REDACTED]"));
3591        assert!(out.contains("cookie: [REDACTED]"));
3592        assert!(out.contains("x-request-id: req-123"));
3593        assert!(!out.contains("secret-token"));
3594    }
3595
3596    // -- security_headers_middleware --
3597
3598    fn security_router(is_tls: bool) -> axum::Router {
3599        security_router_with(is_tls, SecurityHeadersConfig::default())
3600    }
3601
3602    fn security_router_with(is_tls: bool, cfg: SecurityHeadersConfig) -> axum::Router {
3603        let cfg = Arc::new(cfg);
3604        axum::Router::new()
3605            .route("/test", axum::routing::get(|| async { "ok" }))
3606            .layer(axum::middleware::from_fn(move |req, next| {
3607                let c = Arc::clone(&cfg);
3608                security_headers_middleware(is_tls, c, req, next)
3609            }))
3610    }
3611
3612    #[tokio::test]
3613    async fn security_headers_set_on_response() {
3614        let app = security_router(false);
3615        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3616        let resp = app.oneshot(req).await.unwrap();
3617        assert_eq!(resp.status(), StatusCode::OK);
3618
3619        let h = resp.headers();
3620        assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
3621        assert_eq!(h.get("x-frame-options").unwrap(), "deny");
3622        assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
3623        assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
3624        assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
3625        assert_eq!(
3626            h.get("cross-origin-resource-policy").unwrap(),
3627            "same-origin"
3628        );
3629        assert_eq!(
3630            h.get("cross-origin-embedder-policy").unwrap(),
3631            "require-corp"
3632        );
3633        assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
3634        assert!(
3635            h.get("permissions-policy")
3636                .unwrap()
3637                .to_str()
3638                .unwrap()
3639                .contains("camera=()"),
3640            "permissions-policy must restrict browser features"
3641        );
3642        assert_eq!(
3643            h.get("content-security-policy").unwrap(),
3644            "default-src 'none'; frame-ancestors 'none'"
3645        );
3646        assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
3647        // No HSTS when TLS is off.
3648        assert!(h.get("strict-transport-security").is_none());
3649    }
3650
3651    #[tokio::test]
3652    async fn hsts_set_when_tls_enabled() {
3653        let app = security_router(true);
3654        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3655        let resp = app.oneshot(req).await.unwrap();
3656
3657        let hsts = resp.headers().get("strict-transport-security").unwrap();
3658        assert!(
3659            hsts.to_str().unwrap().contains("max-age=63072000"),
3660            "HSTS must set 2-year max-age"
3661        );
3662    }
3663
3664    // -- SecurityHeadersConfig validation + override semantics --
3665
3666    /// Build a minimal config with a custom SecurityHeadersConfig and
3667    /// drive it through `check()`. Returns the result so individual
3668    /// tests can assert on success or specific error messages.
3669    fn check_with_security_headers(headers: SecurityHeadersConfig) -> Result<(), McpxError> {
3670        let cfg =
3671            McpServerConfig::new("127.0.0.1:8080", "test", "0.0.0").with_security_headers(headers);
3672        cfg.check()
3673    }
3674
3675    #[test]
3676    fn security_headers_config_default_validates() {
3677        check_with_security_headers(SecurityHeadersConfig::default())
3678            .expect("default SecurityHeadersConfig must validate");
3679    }
3680
3681    #[test]
3682    fn security_headers_config_validate_accepts_empty_string() {
3683        // All twelve fields explicitly set to "" -> omit-everything mode.
3684        let h = SecurityHeadersConfig {
3685            x_content_type_options: Some(String::new()),
3686            x_frame_options: Some(String::new()),
3687            cache_control: Some(String::new()),
3688            referrer_policy: Some(String::new()),
3689            cross_origin_opener_policy: Some(String::new()),
3690            cross_origin_resource_policy: Some(String::new()),
3691            cross_origin_embedder_policy: Some(String::new()),
3692            permissions_policy: Some(String::new()),
3693            x_permitted_cross_domain_policies: Some(String::new()),
3694            content_security_policy: Some(String::new()),
3695            x_dns_prefetch_control: Some(String::new()),
3696            strict_transport_security: Some(String::new()),
3697        };
3698        check_with_security_headers(h).expect("Some(\"\") on every field must validate (omit-all)");
3699    }
3700
3701    #[test]
3702    fn security_headers_config_validate_rejects_bad_value() {
3703        // 0x07 (BEL) is not a valid HTTP header value char.
3704        let h = SecurityHeadersConfig {
3705            referrer_policy: Some("\u{0007}".into()),
3706            ..SecurityHeadersConfig::default()
3707        };
3708        let err = check_with_security_headers(h)
3709            .expect_err("control char in referrer_policy must reject");
3710        let msg = err.to_string();
3711        assert!(
3712            msg.contains("referrer_policy"),
3713            "error must name the offending field, got: {msg}"
3714        );
3715    }
3716
3717    #[test]
3718    fn security_headers_config_validate_rejects_hsts_preload() {
3719        let h = SecurityHeadersConfig {
3720            strict_transport_security: Some("max-age=63072000; includeSubDomains; preload".into()),
3721            ..SecurityHeadersConfig::default()
3722        };
3723        let err = check_with_security_headers(h).expect_err("HSTS with preload must reject");
3724        let msg = err.to_string();
3725        assert!(
3726            msg.contains("strict_transport_security"),
3727            "error must name the field, got: {msg}"
3728        );
3729        assert!(
3730            msg.to_lowercase().contains("preload"),
3731            "error must mention `preload`, got: {msg}"
3732        );
3733    }
3734
3735    #[test]
3736    fn security_headers_config_validate_rejects_hsts_preload_uppercase() {
3737        // Case-insensitive match.
3738        let h = SecurityHeadersConfig {
3739            strict_transport_security: Some("max-age=600; PRELOAD".into()),
3740            ..SecurityHeadersConfig::default()
3741        };
3742        check_with_security_headers(h).expect_err("HSTS preload check must be case-insensitive");
3743    }
3744
3745    #[tokio::test]
3746    async fn security_headers_override_honored() {
3747        // Override X-Frame-Options to SAMEORIGIN.
3748        let h = SecurityHeadersConfig {
3749            x_frame_options: Some("SAMEORIGIN".into()),
3750            ..SecurityHeadersConfig::default()
3751        };
3752        let app = security_router_with(false, h);
3753        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3754        let resp = app.oneshot(req).await.unwrap();
3755        assert_eq!(resp.status(), StatusCode::OK);
3756
3757        let xfo = resp.headers().get("x-frame-options").unwrap();
3758        assert_eq!(xfo, "SAMEORIGIN");
3759    }
3760
3761    #[tokio::test]
3762    async fn security_headers_empty_string_omits() {
3763        // Empty string on referrer-policy -> header absent.
3764        let h = SecurityHeadersConfig {
3765            referrer_policy: Some(String::new()),
3766            ..SecurityHeadersConfig::default()
3767        };
3768        let app = security_router_with(false, h);
3769        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3770        let resp = app.oneshot(req).await.unwrap();
3771        assert_eq!(resp.status(), StatusCode::OK);
3772
3773        assert!(
3774            resp.headers().get("referrer-policy").is_none(),
3775            "Some(\"\") must omit the header"
3776        );
3777        // Other defaults should still be present.
3778        assert_eq!(
3779            resp.headers().get("x-content-type-options").unwrap(),
3780            "nosniff"
3781        );
3782    }
3783
3784    #[tokio::test]
3785    async fn security_headers_hsts_only_when_tls() {
3786        // HSTS override is irrelevant when TLS is off.
3787        let h = SecurityHeadersConfig {
3788            strict_transport_security: Some("max-age=600".into()),
3789            ..SecurityHeadersConfig::default()
3790        };
3791        let app = security_router_with(false, h);
3792        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3793        let resp = app.oneshot(req).await.unwrap();
3794        assert!(
3795            resp.headers().get("strict-transport-security").is_none(),
3796            "HSTS must remain absent on plaintext deployments even with override"
3797        );
3798    }
3799
3800    // -- oauth_token_cache_headers_middleware --
3801
3802    #[cfg(feature = "oauth")]
3803    #[tokio::test]
3804    async fn oauth_token_cache_headers_set_pragma_and_vary() {
3805        let app = axum::Router::new()
3806            .route("/token", axum::routing::post(|| async { "{}" }))
3807            .layer(axum::middleware::from_fn(
3808                oauth_token_cache_headers_middleware,
3809            ));
3810        let req = Request::builder()
3811            .method("POST")
3812            .uri("/token")
3813            .body(Body::from("{}"))
3814            .unwrap();
3815        let resp = app.oneshot(req).await.unwrap();
3816        assert_eq!(resp.status(), StatusCode::OK);
3817
3818        let h = resp.headers();
3819        assert_eq!(
3820            h.get("pragma").unwrap(),
3821            "no-cache",
3822            "RFC 6749 §5.1: token responses must set Pragma: no-cache"
3823        );
3824        let vary_values: Vec<String> = h
3825            .get_all("vary")
3826            .iter()
3827            .filter_map(|v| v.to_str().ok().map(str::to_owned))
3828            .collect();
3829        assert!(
3830            vary_values
3831                .iter()
3832                .any(|v| v.eq_ignore_ascii_case("Authorization")),
3833            "RFC 6750 §5.4: Vary must include Authorization, got {vary_values:?}"
3834        );
3835    }
3836
3837    #[cfg(feature = "oauth")]
3838    #[tokio::test]
3839    async fn oauth_token_cache_headers_preserve_existing_vary() {
3840        // Simulates a handler/layer that already set `Vary: Accept-Encoding`
3841        // (e.g. compression). Our middleware must APPEND, not REPLACE.
3842        let app = axum::Router::new()
3843            .route(
3844                "/token",
3845                axum::routing::post(|| async {
3846                    axum::response::Response::builder()
3847                        .header("vary", "Accept-Encoding")
3848                        .body(axum::body::Body::from("{}"))
3849                        .unwrap()
3850                }),
3851            )
3852            .layer(axum::middleware::from_fn(
3853                oauth_token_cache_headers_middleware,
3854            ));
3855        let req = Request::builder()
3856            .method("POST")
3857            .uri("/token")
3858            .body(Body::empty())
3859            .unwrap();
3860        let resp = app.oneshot(req).await.unwrap();
3861
3862        let vary: Vec<String> = resp
3863            .headers()
3864            .get_all("vary")
3865            .iter()
3866            .filter_map(|v| v.to_str().ok().map(str::to_owned))
3867            .collect();
3868        assert!(
3869            vary.iter().any(|v| v.contains("Accept-Encoding")),
3870            "must preserve pre-existing Vary value, got {vary:?}"
3871        );
3872        assert!(
3873            vary.iter().any(|v| v.contains("Authorization")),
3874            "must append Authorization to Vary, got {vary:?}"
3875        );
3876    }
3877
3878    // -- version endpoint --
3879
3880    #[test]
3881    fn version_payload_contains_expected_fields() {
3882        let v = version_payload("my-server", "1.2.3");
3883        assert_eq!(v["name"], "my-server");
3884        assert_eq!(v["version"], "1.2.3");
3885        assert!(v["build_git_sha"].is_string());
3886        assert!(v["build_timestamp"].is_string());
3887        assert!(v["rust_version"].is_string());
3888        assert!(v["mcpx_version"].is_string());
3889    }
3890
3891    // -- concurrency limit layer --
3892
3893    #[tokio::test]
3894    async fn concurrency_limit_layer_composes_and_serves() {
3895        // We only assert the layer stack compiles and a single request
3896        // below the cap still succeeds. True back-pressure behaviour
3897        // requires a live HTTP server and is covered by integration tests.
3898        let app = axum::Router::new()
3899            .route("/ok", axum::routing::get(|| async { "ok" }))
3900            .layer(
3901                tower::ServiceBuilder::new()
3902                    .layer(axum::error_handling::HandleErrorLayer::new(
3903                        |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
3904                    ))
3905                    .layer(tower::load_shed::LoadShedLayer::new())
3906                    .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
3907            );
3908        let resp = app
3909            .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
3910            .await
3911            .unwrap();
3912        assert_eq!(resp.status(), StatusCode::OK);
3913    }
3914
3915    // -- compression layer --
3916
3917    #[tokio::test]
3918    async fn compression_layer_gzip_encodes_response() {
3919        use tower_http::compression::Predicate as _;
3920
3921        let big_body = "a".repeat(4096);
3922        let app = axum::Router::new()
3923            .route(
3924                "/big",
3925                axum::routing::get(move || {
3926                    let body = big_body.clone();
3927                    async move { body }
3928                }),
3929            )
3930            .layer(
3931                tower_http::compression::CompressionLayer::new()
3932                    .gzip(true)
3933                    .br(true)
3934                    .compress_when(
3935                        tower_http::compression::DefaultPredicate::new()
3936                            .and(tower_http::compression::predicate::SizeAbove::new(1024)),
3937                    ),
3938            );
3939
3940        let req = Request::builder()
3941            .uri("/big")
3942            .header(header::ACCEPT_ENCODING, "gzip")
3943            .body(Body::empty())
3944            .unwrap();
3945        let resp = app.oneshot(req).await.unwrap();
3946        assert_eq!(resp.status(), StatusCode::OK);
3947        assert_eq!(
3948            resp.headers().get(header::CONTENT_ENCODING).unwrap(),
3949            "gzip"
3950        );
3951    }
3952
3953    // -- TlsListener handshake timeout --
3954
3955    #[tokio::test]
3956    async fn tls_handshake_timeout_reaps_idle_connections() {
3957        use tokio::io::AsyncReadExt as _;
3958
3959        let _ = rustls::crypto::ring::default_provider().install_default();
3960
3961        // Self-signed cert material on disk (TlsListener::new takes paths).
3962        let key = rcgen::KeyPair::generate().expect("generate key");
3963        let cert = rcgen::CertificateParams::new(vec!["localhost".to_owned()])
3964            .expect("cert params")
3965            .self_signed(&key)
3966            .expect("self-signed cert");
3967        let dir = std::env::temp_dir().join(format!(
3968            "rmcp-server-kit-hs-timeout-{}",
3969            std::time::SystemTime::now()
3970                .duration_since(std::time::UNIX_EPOCH)
3971                .expect("clock after epoch")
3972                .as_nanos()
3973        ));
3974        tokio::fs::create_dir_all(&dir).await.expect("temp dir");
3975        let cert_path = dir.join("server.crt");
3976        let key_path = dir.join("server.key");
3977        tokio::fs::write(&cert_path, cert.pem())
3978            .await
3979            .expect("write cert");
3980        tokio::fs::write(&key_path, key.serialize_pem())
3981            .await
3982            .expect("write key");
3983
3984        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
3985        let tls = TlsListener::new(
3986            listener,
3987            &cert_path,
3988            &key_path,
3989            None,
3990            None,
3991            Duration::from_millis(200),
3992            8, // custom concurrency cap: proves the plumbing end-to-end
3993        )
3994        .expect("tls listener");
3995        let addr = axum::serve::Listener::local_addr(&tls).expect("local addr");
3996
3997        // Connect and send NOTHING: the handshake worker must time out
3998        // after 200ms and drop the stream, which the client observes as
3999        // EOF or a reset well within the 2s deadline.
4000        let mut idle = tokio::net::TcpStream::connect(addr).await.expect("connect");
4001        let mut buf = [0_u8; 16];
4002        let read = tokio::time::timeout(Duration::from_secs(2), idle.read(&mut buf))
4003            .await
4004            .expect("server must reap the idle handshake within its timeout");
4005        match read {
4006            Ok(0) | Err(_) => {} // EOF or reset: connection was dropped.
4007            Ok(n) => panic!("unexpected {n} bytes from server during reaped handshake"),
4008        }
4009
4010        drop(tls);
4011    }
4012}