Skip to main content

rmcp_server_kit/
transport.rs

1use std::{
2    future::Future,
3    net::SocketAddr,
4    path::{Path, PathBuf},
5    pin::Pin,
6    sync::Arc,
7    time::Duration,
8};
9
10use arc_swap::ArcSwap;
11use axum::{body::Body, extract::Request, middleware::Next, response::IntoResponse};
12use rmcp::{
13    ServerHandler,
14    transport::streamable_http_server::{
15        StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
16    },
17};
18use rustls::RootCertStore;
19use tokio::{
20    net::TcpListener,
21    sync::{Semaphore, mpsc},
22};
23use tokio_util::sync::CancellationToken;
24
25use crate::{
26    auth::{
27        AuthConfig, AuthIdentity, AuthState, MtlsConfig, TlsConnInfo, auth_middleware,
28        build_rate_limiter, extract_mtls_identity,
29    },
30    error::McpxError,
31    mtls_revocation::{self, CrlSet, DynamicClientCertVerifier},
32    rbac::{RbacPolicy, ToolRateLimiter, build_tool_rate_limiter, rbac_middleware},
33};
34
35/// Map an internal `anyhow::Error` chain into a public [`McpxError::Startup`]
36/// at the public API boundary, flattening the chain via the alternate
37/// formatter so callers see the full causal path.
38#[allow(
39    clippy::needless_pass_by_value,
40    reason = "consumed at .map_err(anyhow_to_startup) call sites; by-value matches the closure shape"
41)]
42fn anyhow_to_startup(e: anyhow::Error) -> McpxError {
43    McpxError::Startup(format!("{e:#}"))
44}
45
46/// Map a `std::io::Error` produced during server startup into a public
47/// [`McpxError::Startup`]. We deliberately do not use the [`McpxError::Io`]
48/// `From` impl here because startup-phase IO errors (bind, listener) are
49/// semantically distinct from request-time IO errors and should surface
50/// the originating operation in the message.
51#[allow(
52    clippy::needless_pass_by_value,
53    reason = "consumed at .map_err(|e| io_to_startup(...)) call sites; by-value matches the closure shape"
54)]
55fn io_to_startup(op: &str, e: std::io::Error) -> McpxError {
56    McpxError::Startup(format!("{op}: {e}"))
57}
58
/// Async readiness check callback for the `/readyz` endpoint.
///
/// The callback is a shareable (`Arc`), thread-safe (`Send + Sync`)
/// closure that takes no arguments and produces a boxed, `Send` future
/// resolving to a [`serde_json::Value`].
///
/// Returns a JSON object with at least a `"ready"` boolean.
/// When `ready` is false, the endpoint returns HTTP 503.
pub type ReadinessCheck =
    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> + Send + Sync>;
65
/// Per-header overrides for the OWASP security headers emitted by the
/// global response middleware.
///
/// Each field follows a three-state semantic:
///
/// | Value         | Behaviour                                                |
/// |---------------|----------------------------------------------------------|
/// | `None`        | Use the built-in default (current behaviour).            |
/// | `Some("")`    | **Omit** the header entirely from responses.             |
/// | `Some(value)` | Emit `header: value`. Validated at config-load time.     |
///
/// All non-empty values are validated via
/// [`axum::http::HeaderValue::from_str`] inside
/// [`McpServerConfig::validate`]; invalid values fail fast before the
/// server starts accepting traffic.
///
/// `Strict-Transport-Security` has an additional rule: the substring
/// `preload` (case-insensitive) is rejected. Operators who want to
/// commit to the HSTS preload list must do so via a future explicit
/// builder method, not by smuggling it through this knob.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot
/// build it with a struct literal; start from
/// `SecurityHeadersConfig::default()` (every field `None`) and assign the
/// fields to override.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct SecurityHeadersConfig {
    /// Override for `X-Content-Type-Options`. Default: `nosniff`.
    pub x_content_type_options: Option<String>,
    /// Override for `X-Frame-Options`. Default: `deny`.
    pub x_frame_options: Option<String>,
    /// Override for `Cache-Control`. Default: `no-store, max-age=0`.
    pub cache_control: Option<String>,
    /// Override for `Referrer-Policy`. Default: `no-referrer`.
    pub referrer_policy: Option<String>,
    /// Override for `Cross-Origin-Opener-Policy`. Default: `same-origin`.
    pub cross_origin_opener_policy: Option<String>,
    /// Override for `Cross-Origin-Resource-Policy`. Default: `same-origin`.
    pub cross_origin_resource_policy: Option<String>,
    /// Override for `Cross-Origin-Embedder-Policy`. Default: `require-corp`.
    pub cross_origin_embedder_policy: Option<String>,
    /// Override for `Permissions-Policy`. Default:
    /// `accelerometer=(), camera=(), geolocation=(), microphone=()`.
    pub permissions_policy: Option<String>,
    /// Override for `X-Permitted-Cross-Domain-Policies`. Default: `none`.
    pub x_permitted_cross_domain_policies: Option<String>,
    /// Override for `Content-Security-Policy`. Default:
    /// `default-src 'none'; frame-ancestors 'none'`.
    pub content_security_policy: Option<String>,
    /// Override for `X-DNS-Prefetch-Control`. Default: `off`.
    pub x_dns_prefetch_control: Option<String>,
    /// Override for `Strict-Transport-Security`. Default (TLS only):
    /// `max-age=63072000; includeSubDomains`. Only emitted when TLS is
    /// active; the override is ignored on plaintext deployments. The
    /// substring `preload` (any case) is rejected by the validator.
    pub strict_transport_security: Option<String>,
}
119
/// Configuration for the MCP server.
///
/// Build a value with [`McpServerConfig::new`], customise it through the
/// chainable `with_*` / `enable_*` builder methods, then call
/// [`McpServerConfig::validate`] to obtain the [`Validated`] proof token
/// required by [`serve`] and [`serve_with_listener`].
///
/// Every field is still `pub` for backward compatibility but is marked
/// `#[deprecated]`: each deprecation note names the builder method that
/// replaces direct access.
#[allow(
    missing_debug_implementations,
    reason = "contains callback/trait objects that don't impl Debug"
)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "server configuration naturally has many boolean feature flags"
)]
#[non_exhaustive]
pub struct McpServerConfig {
    /// Socket address the MCP HTTP server binds to.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::new() / with_bind_addr(); direct field access will become pub(crate) in a future major release"
    )]
    pub bind_addr: String,
    /// Server name advertised via MCP `initialize`.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub name: String,
    /// Server version advertised via MCP `initialize`.
    #[deprecated(
        since = "0.13.0",
        note = "set via McpServerConfig::new(); direct field access will become pub(crate) in a future major release"
    )]
    pub version: String,
    /// Path to the TLS certificate (PEM). Required for TLS/mTLS.
    /// Must be set together with `tls_key_path`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_cert_path: Option<PathBuf>,
    /// Path to the TLS private key (PEM). Required for TLS/mTLS.
    /// Must be set together with `tls_cert_path`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tls(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_key_path: Option<PathBuf>,
    /// Optional authentication config. When `Some` and `enabled`, auth
    /// is enforced on `/mcp`. `/healthz` is always open.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_auth(); direct field access will become pub(crate) in a future major release"
    )]
    pub auth: Option<AuthConfig>,
    /// Optional RBAC policy. When present and enabled, tool calls are
    /// checked against the policy after authentication.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_rbac(); direct field access will become pub(crate) in a future major release"
    )]
    pub rbac: Option<Arc<RbacPolicy>>,
    /// Allowed Origin values for DNS rebinding protection (MCP spec MUST).
    /// When empty and `public_url` is set, the origin is auto-derived from
    /// the public URL. When both are empty, only requests with no Origin
    /// header are accepted.
    /// Example entries: `"http://localhost:3000"`, `"https://myapp.example.com"`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_allowed_origins(); direct field access will become pub(crate) in a future major release"
    )]
    pub allowed_origins: Vec<String>,
    /// Maximum tool invocations per source IP per minute.
    /// When set, enforced on every `tools/call` request.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_tool_rate_limit(); direct field access will become pub(crate) in a future major release"
    )]
    pub tool_rate_limit: Option<u32>,
    /// Optional readiness probe for `/readyz`.
    /// When `None`, `/readyz` mirrors `/healthz` (always OK).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_readiness_check(); direct field access will become pub(crate) in a future major release"
    )]
    pub readiness_check: Option<ReadinessCheck>,
    /// Maximum request body size in bytes. Default: 1 MiB.
    /// Protects against oversized payloads causing OOM.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_request_body(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_request_body: usize,
    /// Request processing timeout. Default: 120s.
    /// Requests exceeding this duration receive 408 Request Timeout.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_request_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub request_timeout: Duration,
    /// Graceful shutdown timeout. Default: 30s.
    /// After the shutdown signal, in-flight requests have this long to finish.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_shutdown_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub shutdown_timeout: Duration,
    /// Idle timeout for MCP sessions. Sessions with no activity for this
    /// duration are closed automatically. Default: 20 minutes.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_session_idle_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub session_idle_timeout: Duration,
    /// Interval for SSE keep-alive pings. Prevents proxies and load
    /// balancers from killing idle connections. Default: 15 seconds.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_sse_keep_alive(); direct field access will become pub(crate) in a future major release"
    )]
    pub sse_keep_alive: Duration,
    /// Callback invoked once the server is built, delivering a
    /// [`ReloadHandle`] for hot-reloading auth keys and RBAC policy
    /// at runtime (e.g. on SIGHUP). Only useful when auth/RBAC is enabled.
    /// The callback is `FnOnce`, so it can run at most once.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_reload_callback(); direct field access will become pub(crate) in a future major release"
    )]
    pub on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
    /// Additional application-specific routes merged into the top-level
    /// router.  These routes **bypass** the MCP auth and RBAC middleware,
    /// so the application is responsible for its own auth on them.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_extra_router(); direct field access will become pub(crate) in a future major release"
    )]
    pub extra_router: Option<axum::Router>,
    /// Externally reachable base URL (e.g. `https://mcp.example.com`).
    /// When set, OAuth metadata endpoints advertise this URL instead of
    /// the listen address. Required when binding `0.0.0.0` behind a
    /// reverse proxy or inside a container.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_public_url(); direct field access will become pub(crate) in a future major release"
    )]
    pub public_url: Option<String>,
    /// Log inbound HTTP request headers at DEBUG level.
    /// Sensitive values remain redacted.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_request_header_logging(); direct field access will become pub(crate) in a future major release"
    )]
    pub log_request_headers: bool,
    /// Enable gzip/br response compression on MCP responses.
    /// Defaults to `false` to preserve existing behaviour.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_enabled: bool,
    /// Minimum response body size (in bytes) before compression kicks in.
    /// Only used when `compression_enabled` is true. Default: 1024.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_compression(); direct field access will become pub(crate) in a future major release"
    )]
    pub compression_min_size: u16,
    /// Global cap on in-flight HTTP requests across the whole server.
    /// When `Some`, requests over the cap receive 503 Service Unavailable
    /// via `tower::load_shed`. Default: `None` (unlimited).
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_max_concurrent_requests(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_concurrent_requests: Option<usize>,
    /// Enable `/admin/*` diagnostic endpoints. Requires `auth` to be
    /// configured and `enabled`. Default: `false`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_enabled: bool,
    /// RBAC role required to access admin endpoints. Default: `"admin"`.
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::enable_admin(); direct field access will become pub(crate) in a future major release"
    )]
    pub admin_role: String,
    /// Enable Prometheus metrics endpoint on a separate listener.
    /// Requires the `metrics` crate feature. Default: `false`.
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_enabled: bool,
    /// Bind address for the Prometheus metrics listener. Default: `127.0.0.1:9090`.
    #[cfg(feature = "metrics")]
    #[deprecated(
        since = "0.13.0",
        note = "use McpServerConfig::with_metrics(); direct field access will become pub(crate) in a future major release"
    )]
    pub metrics_bind: String,
    /// Per-header overrides for the OWASP security headers emitted by
    /// the global response middleware. See [`SecurityHeadersConfig`]
    /// for the three-state semantic and validation rules.
    #[deprecated(
        since = "1.5.0",
        note = "use McpServerConfig::with_security_headers(); direct field access will become pub(crate) in a future major release"
    )]
    pub security_headers: SecurityHeadersConfig,
    /// Per-handshake deadline on the TLS accept path. Idle or slow-loris
    /// connections are dropped once it elapses. Default: 10s.
    ///
    /// Startup-only: bound at listener construction, NOT hot-reloadable
    /// via [`ReloadHandle`]. Ignored unless TLS is configured.
    #[deprecated(
        since = "1.9.0",
        note = "use McpServerConfig::with_tls_handshake_timeout(); direct field access will become pub(crate) in a future major release"
    )]
    pub tls_handshake_timeout: Duration,
    /// Cap on concurrently in-flight TLS handshakes. At saturation the
    /// acceptor stops pulling new connections from the kernel backlog
    /// (backpressure) instead of accepting and dropping. Default: 256.
    ///
    /// Startup-only: bound at listener construction, NOT hot-reloadable
    /// via [`ReloadHandle`]. Ignored unless TLS is configured.
    #[deprecated(
        since = "1.9.0",
        note = "use McpServerConfig::with_max_concurrent_tls_handshakes(); direct field access will become pub(crate) in a future major release"
    )]
    pub max_concurrent_tls_handshakes: usize,
}
346
/// Proof-carrying wrapper: holding a `Validated<T>` means the inner
/// value has passed its validation contract.
///
/// For [`McpServerConfig`] the only constructor is
/// [`McpServerConfig::validate`]; [`serve`] and [`serve_with_listener`]
/// accept nothing but the wrapped form, so the requirement is enforced
/// at the type level. Because the tuple field is private, downstream
/// crates cannot assemble the wrapper by hand and skip validation.
///
/// Read access goes through [`Validated::as_inner`]. There is no
/// mutable access: to change the value, unwrap it with
/// [`Validated::into_inner`], modify it, and validate again.
///
/// # Example
///
/// ```no_run
/// use rmcp_server_kit::transport::{McpServerConfig, Validated, serve};
/// use rmcp::handler::server::ServerHandler;
/// use rmcp::model::{ServerCapabilities, ServerInfo};
///
/// #[derive(Clone)]
/// struct H;
/// impl ServerHandler for H {
///     fn get_info(&self) -> ServerInfo {
///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
///     }
/// }
///
/// # async fn example() -> rmcp_server_kit::Result<()> {
/// let config: Validated<McpServerConfig> =
///     McpServerConfig::new("127.0.0.1:8080", "my-server", "0.1.0").validate()?;
/// serve(config, || H).await
/// # }
/// ```
///
/// Forgetting `.validate()?` is a compile error:
///
/// ```compile_fail
/// use rmcp_server_kit::transport::{McpServerConfig, serve};
/// use rmcp::handler::server::ServerHandler;
/// use rmcp::model::{ServerCapabilities, ServerInfo};
///
/// #[derive(Clone)]
/// struct H;
/// impl ServerHandler for H {
///     fn get_info(&self) -> ServerInfo {
///         ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
///     }
/// }
///
/// # async fn example() -> rmcp_server_kit::Result<()> {
/// let config = McpServerConfig::new("127.0.0.1:8080", "my-server", "0.1.0");
/// // Missing `.validate()?` -> mismatched types: expected
/// // `Validated<McpServerConfig>`, found `McpServerConfig`.
/// serve(config, || H).await
/// # }
/// ```
#[allow(
    missing_debug_implementations,
    reason = "wraps T which may not implement Debug; manual impl below avoids leaking inner contents into logs"
)]
pub struct Validated<T>(T);

impl<T> std::fmt::Debug for Validated<T> {
    /// Print an opaque `Validated { .. }` placeholder. The wrapped value
    /// is never formatted, so `T` needs no `Debug` impl and its contents
    /// cannot end up in logs through this path.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut opaque = f.debug_struct("Validated");
        opaque.finish_non_exhaustive()
    }
}

impl<T> Validated<T> {
    /// Shared reference to the wrapped value.
    #[must_use]
    pub fn as_inner(&self) -> &T {
        let Self(inner) = self;
        inner
    }

    /// Unwrap the value, giving up the validation proof.
    ///
    /// The returned value must be validated again before it can be
    /// handed to [`serve`] or [`serve_with_listener`].
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}
432
433#[allow(
434    deprecated,
435    reason = "internal builders/validators legitimately read/write the deprecated `pub` fields they were designed to manage"
436)]
437impl McpServerConfig {
438    /// Create a new server configuration with the given bind address,
439    /// server name, and version. All other fields use safe defaults.
440    ///
441    /// Use the chainable `with_*` / `enable_*` builder methods to
442    /// customize. Call [`McpServerConfig::validate`] to obtain a
443    /// [`Validated<McpServerConfig>`] proof token, which is required by
444    /// [`serve`] and [`serve_with_listener`].
445    #[must_use]
446    pub fn new(
447        bind_addr: impl Into<String>,
448        name: impl Into<String>,
449        version: impl Into<String>,
450    ) -> Self {
451        Self {
452            bind_addr: bind_addr.into(),
453            name: name.into(),
454            version: version.into(),
455            tls_cert_path: None,
456            tls_key_path: None,
457            auth: None,
458            rbac: None,
459            allowed_origins: Vec::new(),
460            tool_rate_limit: None,
461            readiness_check: None,
462            max_request_body: 1024 * 1024,
463            request_timeout: Duration::from_mins(2),
464            shutdown_timeout: Duration::from_secs(30),
465            session_idle_timeout: Duration::from_mins(20),
466            sse_keep_alive: Duration::from_secs(15),
467            on_reload_ready: None,
468            extra_router: None,
469            public_url: None,
470            log_request_headers: false,
471            compression_enabled: false,
472            compression_min_size: 1024,
473            max_concurrent_requests: None,
474            admin_enabled: false,
475            admin_role: "admin".to_owned(),
476            #[cfg(feature = "metrics")]
477            metrics_enabled: false,
478            #[cfg(feature = "metrics")]
479            metrics_bind: "127.0.0.1:9090".into(),
480            security_headers: SecurityHeadersConfig::default(),
481            tls_handshake_timeout: DEFAULT_TLS_HANDSHAKE_TIMEOUT,
482            max_concurrent_tls_handshakes: DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES,
483        }
484    }
485
486    // ---------------------------------------------------------------
487    // Builder methods (fluent, consume + return self).
488    //
489    // Each method is `#[must_use]` because dropping the returned
490    // `McpServerConfig` discards the configuration change.
491    // ---------------------------------------------------------------
492
493    /// Attach an authentication configuration. Required for
494    /// [`enable_admin`](Self::enable_admin) and any non-public deployment.
495    #[must_use]
496    pub fn with_auth(mut self, auth: AuthConfig) -> Self {
497        self.auth = Some(auth);
498        self
499    }
500
501    /// Override one or more of the OWASP security headers emitted on
502    /// every response. See [`SecurityHeadersConfig`] for the three-state
503    /// semantic (`None` = default, `Some("")` = omit, `Some(v)` =
504    /// override). Values are validated by [`Self::validate`].
505    #[must_use]
506    pub fn with_security_headers(mut self, headers: SecurityHeadersConfig) -> Self {
507        self.security_headers = headers;
508        self
509    }
510
511    /// Override the bind address (e.g. `127.0.0.1:8080`). Useful when the
512    /// final port is only known after pre-binding an ephemeral listener
513    /// (tests, dynamic-port deployments).
514    #[must_use]
515    pub fn with_bind_addr(mut self, addr: impl Into<String>) -> Self {
516        self.bind_addr = addr.into();
517        self
518    }
519
520    /// Attach an RBAC policy. Tool calls are checked against the policy
521    /// after authentication.
522    #[must_use]
523    pub fn with_rbac(mut self, rbac: Arc<RbacPolicy>) -> Self {
524        self.rbac = Some(rbac);
525        self
526    }
527
528    /// Configure TLS by providing the certificate and private key paths
529    /// (PEM). Both must be readable at startup. Without this call, the
530    /// server runs plain HTTP.
531    #[must_use]
532    pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
533        self.tls_cert_path = Some(cert_path.into());
534        self.tls_key_path = Some(key_path.into());
535        self
536    }
537
538    /// Set the externally reachable base URL (e.g. `https://mcp.example.com`).
539    /// Required when binding `0.0.0.0` behind a reverse proxy or inside
540    /// a container so OAuth metadata and auto-derived origins resolve correctly.
541    #[must_use]
542    pub fn with_public_url(mut self, url: impl Into<String>) -> Self {
543        self.public_url = Some(url.into());
544        self
545    }
546
547    /// Replace the allowed Origin allow-list (DNS-rebinding protection).
548    /// When empty and [`with_public_url`](Self::with_public_url) is set,
549    /// the origin is auto-derived.
550    #[must_use]
551    pub fn with_allowed_origins<I, S>(mut self, origins: I) -> Self
552    where
553        I: IntoIterator<Item = S>,
554        S: Into<String>,
555    {
556        self.allowed_origins = origins.into_iter().map(Into::into).collect();
557        self
558    }
559
560    /// Merge an additional axum router at the top level. Routes added
561    /// here **bypass** rmcp-server-kit auth and RBAC; the application is responsible
562    /// for its own protection.
563    #[must_use]
564    pub fn with_extra_router(mut self, router: axum::Router) -> Self {
565        self.extra_router = Some(router);
566        self
567    }
568
569    /// Install an async readiness probe for `/readyz`. Without this call,
570    /// `/readyz` mirrors `/healthz` (always 200 OK).
571    #[must_use]
572    pub fn with_readiness_check(mut self, check: ReadinessCheck) -> Self {
573        self.readiness_check = Some(check);
574        self
575    }
576
577    /// Override the maximum request body (bytes). Must be `> 0`.
578    /// Default: 1 MiB.
579    #[must_use]
580    pub fn with_max_request_body(mut self, bytes: usize) -> Self {
581        self.max_request_body = bytes;
582        self
583    }
584
585    /// Override the per-request processing timeout. Default: 2 minutes.
586    #[must_use]
587    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
588        self.request_timeout = timeout;
589        self
590    }
591
592    /// Override the graceful shutdown grace period. Default: 30 seconds.
593    #[must_use]
594    pub fn with_shutdown_timeout(mut self, timeout: Duration) -> Self {
595        self.shutdown_timeout = timeout;
596        self
597    }
598
599    /// Override the MCP session idle timeout. Default: 20 minutes.
600    #[must_use]
601    pub fn with_session_idle_timeout(mut self, timeout: Duration) -> Self {
602        self.session_idle_timeout = timeout;
603        self
604    }
605
606    /// Override the SSE keep-alive interval. Default: 15 seconds.
607    #[must_use]
608    pub fn with_sse_keep_alive(mut self, interval: Duration) -> Self {
609        self.sse_keep_alive = interval;
610        self
611    }
612
613    /// Cap the global number of in-flight HTTP requests via
614    /// `tower::load_shed`. Excess requests receive 503 Service Unavailable.
615    /// Default: unlimited.
616    #[must_use]
617    pub fn with_max_concurrent_requests(mut self, limit: usize) -> Self {
618        self.max_concurrent_requests = Some(limit);
619        self
620    }
621
622    /// Override the per-handshake deadline on the TLS accept path.
623    /// Idle or slow-loris connections are dropped once it elapses.
624    /// Default: 10s. Must be greater than zero.
625    ///
626    /// Startup-only: the value is bound at listener construction and is
627    /// NOT hot-reloadable via [`ReloadHandle`]. Has no effect unless TLS
628    /// is configured via [`Self::with_tls`].
629    #[must_use]
630    pub fn with_tls_handshake_timeout(mut self, timeout: Duration) -> Self {
631        self.tls_handshake_timeout = timeout;
632        self
633    }
634
635    /// Override the cap on concurrently in-flight TLS handshakes. At
636    /// saturation the acceptor stops pulling new connections from the
637    /// kernel backlog (backpressure) instead of accepting and dropping.
638    /// Default: 256. Must be greater than zero.
639    ///
640    /// Startup-only: the value is bound at listener construction and is
641    /// NOT hot-reloadable via [`ReloadHandle`]. Has no effect unless TLS
642    /// is configured via [`Self::with_tls`].
643    #[must_use]
644    pub fn with_max_concurrent_tls_handshakes(mut self, limit: usize) -> Self {
645        self.max_concurrent_tls_handshakes = limit;
646        self
647    }
648
649    /// Cap tool invocations per source IP per minute. Enforced on every
650    /// `tools/call` request.
651    #[must_use]
652    pub fn with_tool_rate_limit(mut self, per_minute: u32) -> Self {
653        self.tool_rate_limit = Some(per_minute);
654        self
655    }
656
657    /// Register a callback that receives the [`ReloadHandle`] after the
658    /// server is built. Use it to wire SIGHUP-style hot reloads of API
659    /// keys and RBAC policy.
660    #[must_use]
661    pub fn with_reload_callback<F>(mut self, callback: F) -> Self
662    where
663        F: FnOnce(ReloadHandle) + Send + 'static,
664    {
665        self.on_reload_ready = Some(Box::new(callback));
666        self
667    }
668
669    /// Enable gzip/brotli response compression on MCP responses.
670    /// `min_size` is the smallest body size (bytes) eligible for
671    /// compression. Default min size: 1024.
672    #[must_use]
673    pub fn enable_compression(mut self, min_size: u16) -> Self {
674        self.compression_enabled = true;
675        self.compression_min_size = min_size;
676        self
677    }
678
679    /// Enable `/admin/*` diagnostic endpoints. Requires
680    /// [`with_auth`](Self::with_auth) to be set and enabled; otherwise
681    /// [`validate`](Self::validate) returns an error. `role` is the RBAC
682    /// role gate (default: `"admin"`).
683    #[must_use]
684    pub fn enable_admin(mut self, role: impl Into<String>) -> Self {
685        self.admin_enabled = true;
686        self.admin_role = role.into();
687        self
688    }
689
690    /// Log inbound HTTP request headers at DEBUG level. Sensitive
691    /// values remain redacted by the logging layer.
692    #[must_use]
693    pub fn enable_request_header_logging(mut self) -> Self {
694        self.log_request_headers = true;
695        self
696    }
697
698    /// Enable the Prometheus metrics listener on `bind` (e.g.
699    /// `127.0.0.1:9090`). Requires the `metrics` crate feature.
700    #[cfg(feature = "metrics")]
701    #[must_use]
702    pub fn with_metrics(mut self, bind: impl Into<String>) -> Self {
703        self.metrics_enabled = true;
704        self.metrics_bind = bind.into();
705        self
706    }
707
708    /// Validate the configuration and consume `self`, returning a
709    /// [`Validated<McpServerConfig>`] proof token required by [`serve`]
710    /// and [`serve_with_listener`]. This is the only way to construct
711    /// `Validated<McpServerConfig>`, so the type system guarantees
712    /// validation has run before the server starts.
713    ///
714    /// Checks:
715    ///
716    /// 1. `admin_enabled` requires `auth` to be configured and enabled.
717    /// 2. `tls_cert_path` and `tls_key_path` must both be set or both
718    ///    be unset.
719    /// 3. `bind_addr` must parse as a [`SocketAddr`].
720    /// 4. `public_url`, when set, must start with `http://` or `https://`.
721    /// 5. Each entry in `allowed_origins` must start with `http://` or
722    ///    `https://`.
723    /// 6. `max_request_body` must be greater than zero.
724    /// 7. When the `oauth` feature is enabled and an [`OAuthConfig`] is
725    ///    present, all OAuth URL fields (`jwks_uri`, `proxy.authorize_url`,
726    ///    `proxy.token_url`, `proxy.introspection_url`,
727    ///    `proxy.revocation_url`, `token_exchange.token_url`) must parse
728    ///    and use the `https` scheme. Set
729    ///    [`OAuthConfig::allow_http_oauth_urls`] to permit `http://`
730    ///    targets (strongly discouraged in production - see the field-level
731    ///    docs for the threat model).
732    ///
733    /// [`OAuthConfig`]: crate::oauth::OAuthConfig
734    /// [`OAuthConfig::allow_http_oauth_urls`]: crate::oauth::OAuthConfig::allow_http_oauth_urls
735    ///
736    /// # Errors
737    ///
738    /// Returns [`McpxError::Config`] with a human-readable message on
739    /// the first validation failure.
740    pub fn validate(self) -> Result<Validated<Self>, McpxError> {
741        self.check()?;
742        Ok(Validated(self))
743    }
744
745    /// Run the validation checks without consuming `self`. Used by
746    /// internal call sites (e.g. tests) that need to inspect a config
747    /// without taking ownership.
748    fn check(&self) -> Result<(), McpxError> {
749        // 1. admin <-> auth dependency. Mirrors the runtime check in
750        //    `build_app_router`: admin endpoints require an auth state,
751        //    which is built only when `auth` is `Some` *and* `enabled`.
752        if self.admin_enabled {
753            let auth_enabled = self.auth.as_ref().is_some_and(|a| a.enabled);
754            if !auth_enabled {
755                return Err(McpxError::Config(
756                    "admin_enabled=true requires auth to be configured and enabled".into(),
757                ));
758            }
759        }
760
761        // 2. TLS cert / key must be paired
762        match (&self.tls_cert_path, &self.tls_key_path) {
763            (Some(_), None) => {
764                return Err(McpxError::Config(
765                    "tls_cert_path is set but tls_key_path is missing".into(),
766                ));
767            }
768            (None, Some(_)) => {
769                return Err(McpxError::Config(
770                    "tls_key_path is set but tls_cert_path is missing".into(),
771                ));
772            }
773            _ => {}
774        }
775
776        // 3. bind_addr parses
777        if self.bind_addr.parse::<SocketAddr>().is_err() {
778            return Err(McpxError::Config(format!(
779                "bind_addr {:?} is not a valid socket address (expected e.g. 127.0.0.1:8080)",
780                self.bind_addr
781            )));
782        }
783
784        // 4. public_url scheme
785        if let Some(ref url) = self.public_url
786            && !(url.starts_with("http://") || url.starts_with("https://"))
787        {
788            return Err(McpxError::Config(format!(
789                "public_url {url:?} must start with http:// or https://"
790            )));
791        }
792
793        // 5. allowed_origins scheme
794        for origin in &self.allowed_origins {
795            if !(origin.starts_with("http://") || origin.starts_with("https://")) {
796                return Err(McpxError::Config(format!(
797                    "allowed_origins entry {origin:?} must start with http:// or https://"
798                )));
799            }
800        }
801
802        // 6. max_request_body > 0
803        if self.max_request_body == 0 {
804            return Err(McpxError::Config(
805                "max_request_body must be greater than zero".into(),
806            ));
807        }
808
809        // 7. OAuth URL fields enforce HTTPS (unless `allow_http_oauth_urls`)
810        #[cfg(feature = "oauth")]
811        if let Some(auth_cfg) = &self.auth
812            && let Some(oauth_cfg) = &auth_cfg.oauth
813        {
814            oauth_cfg.validate()?;
815        }
816
817        // 8. Security-header overrides parse as valid HTTP header values,
818        //    and HSTS does not smuggle in a `preload` directive.
819        validate_security_headers(&self.security_headers)?;
820
821        // 9. max_concurrent_requests must be > 0 when set. Zero would
822        //    deadlock the global concurrency limiter and reject every
823        //    request. Mirrors the TOML-side check in `src/config.rs`.
824        if let Some(0) = self.max_concurrent_requests {
825            return Err(McpxError::Config(
826                "max_concurrent_requests must be greater than zero when set".into(),
827            ));
828        }
829
830        // 10. Auth rate-limit `max_tracked_keys` must be > 0. A zero cap
831        //     would force `BoundedKeyedLimiter` to evict on every insert
832        //     and effectively disable rate limiting.
833        if let Some(auth_cfg) = &self.auth
834            && let Some(rl) = &auth_cfg.rate_limit
835            && rl.max_tracked_keys == 0
836        {
837            return Err(McpxError::Config(
838                "auth.rate_limit.max_tracked_keys must be greater than zero".into(),
839            ));
840        }
841
842        // 11. tls_handshake_timeout must be > 0. A zero deadline would
843        //     reap every handshake before it could complete, rejecting
844        //     all TLS connections. Mirrors the TOML-side check in
845        //     `src/config.rs`.
846        if self.tls_handshake_timeout == Duration::ZERO {
847            return Err(McpxError::Config(
848                "tls_handshake_timeout must be greater than zero".into(),
849            ));
850        }
851
852        // 12. max_concurrent_tls_handshakes must be > 0. A zero-permit
853        //     semaphore would never admit a handshake, deadlocking the
854        //     TLS accept path. Mirrors the TOML-side check in
855        //     `src/config.rs`.
856        if self.max_concurrent_tls_handshakes == 0 {
857            return Err(McpxError::Config(
858                "max_concurrent_tls_handshakes must be greater than zero".into(),
859            ));
860        }
861
862        Ok(())
863    }
864}
865
/// Handle for hot-reloading server configuration without restart.
///
/// Obtained via [`McpServerConfig::on_reload_ready`].
/// All swap operations are lock-free and wait-free -- in-flight requests
/// finish with the old values while new requests see the update immediately.
///
/// Every field is optional: a reload method whose backing state is
/// `None` either does nothing (`reload_auth_keys`, `reload_rbac`) or
/// returns an error (`refresh_crls`).
#[allow(
    missing_debug_implementations,
    reason = "contains Arc<AuthState> with non-Debug fields"
)]
pub struct ReloadHandle {
    /// Shared auth state; target of `reload_auth_keys`.
    auth: Option<Arc<AuthState>>,
    /// Swappable RBAC policy; target of `reload_rbac`.
    rbac: Option<Arc<ArcSwap<RbacPolicy>>>,
    /// mTLS CRL set; target of `refresh_crls`.
    crl_set: Option<Arc<CrlSet>>,
}
880
881impl ReloadHandle {
882    /// Atomically replace the API key list used by the auth middleware.
883    pub fn reload_auth_keys(&self, keys: Vec<crate::auth::ApiKeyEntry>) {
884        if let Some(ref auth) = self.auth {
885            auth.reload_keys(keys);
886        }
887    }
888
889    /// Atomically replace the RBAC policy used by the RBAC middleware.
890    pub fn reload_rbac(&self, policy: RbacPolicy) {
891        if let Some(ref rbac) = self.rbac {
892            rbac.store(Arc::new(policy));
893            tracing::info!("RBAC policy reloaded");
894        }
895    }
896
897    /// Force an immediate refresh of all cached mTLS CRLs.
898    ///
899    /// # Errors
900    ///
901    /// Returns an error if CRL refresh is unavailable or verifier rebuild fails.
902    pub async fn refresh_crls(&self) -> Result<(), McpxError> {
903        let Some(ref crl_set) = self.crl_set else {
904            return Err(McpxError::Config(
905                "CRL refresh requested but mTLS CRL support is not configured".into(),
906            ));
907        };
908
909        crl_set.force_refresh().await
910    }
911}
912
913/// Generic MCP HTTP server.
914///
915/// Wraps an axum server with `/healthz` and `/mcp` endpoints.
916/// When `tls_cert_path` and `tls_key_path` are both set, the server binds
917/// with TLS (rustls). Optionally supports mTLS client certificate auth.
918///
919/// # Errors
920///
921/// Returns an error if the TCP listener cannot bind, TLS config is invalid,
922/// or the server fails.
923// NOTE: cognitive complexity reduced from 111/25 to 83/25 by
924// extracting `run_server` (serve-loop tail) and `install_oauth_proxy_routes`.
925// Remaining flow is a linear router builder: middleware layering, feature-
926// gated auth/RBAC wiring, and PRM/metrics installation. Further extraction
927// would require threading many `&mut Router` helpers and hurt readability
928// of the layer order (which is security-relevant and must stay visible).
929#[allow(
930    clippy::too_many_lines,
931    clippy::cognitive_complexity,
932    reason = "middleware layer order is security-critical and must remain visible at one glance; extracting `&mut Router` helpers would obscure the auth/RBAC/origin/rate-limit ordering"
933)]
934/// Internal bundle of values produced by [`build_app_router`] and
935/// consumed by [`serve`] / [`serve_with_listener`] when driving the
936/// HTTP listener.
937struct AppRunParams {
938    /// TLS cert/key paths when TLS is configured.
939    tls_paths: Option<(PathBuf, PathBuf)>,
940    /// Per-handshake deadline on the TLS accept path.
941    tls_handshake_timeout: Duration,
942    /// Cap on concurrently in-flight TLS handshakes.
943    max_concurrent_tls_handshakes: usize,
944    /// mTLS configuration when mutual-TLS auth is enabled.
945    mtls_config: Option<MtlsConfig>,
946    /// Graceful shutdown drain window.
947    shutdown_timeout: Duration,
948    /// Shared auth state used by hot-reload callbacks.
949    auth_state: Option<Arc<AuthState>>,
950    /// Hot-reloadable RBAC state used by reload callbacks.
951    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
952    /// Optional callback that receives the final [`ReloadHandle`].
953    on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
954    /// Server-internal cancellation token. Cancelled by [`run_server`]
955    /// once the shutdown trigger fires (so rmcp's child token also
956    /// fires, terminating in-flight MCP sessions).
957    ct: CancellationToken,
958    /// `"http"` or `"https"` -- used only for boot-time logging.
959    scheme: &'static str,
960    /// Server name -- used only for boot-time logging.
961    name: String,
962}
963
/// Build the full application axum [`axum::Router`] (MCP route +
/// middleware stack + admin + OAuth + health endpoints + security
/// headers + CORS + compression + concurrency limit + origin check)
/// and the [`AppRunParams`] needed to drive it.
///
/// This is the shared core of [`serve`] and [`serve_with_listener`].
/// It does not bind or accept on the main listener: callers are
/// responsible for binding (or accepting a pre-bound) [`TcpListener`]
/// and invoking [`run_server`].
///
/// Side effect: when the `metrics` feature is compiled in and
/// `config.metrics_enabled` is set, this function calls `tokio::spawn`
/// to run `crate::metrics::serve_metrics` on `config.metrics_bind`, so
/// in that configuration it must be called from within a Tokio runtime.
///
/// # Errors
///
/// Returns an error when:
///
/// - (`oauth` feature) `crate::oauth::JwksCache::new` fails;
/// - `admin_enabled` is set but no auth state was built (auth missing
///   or disabled);
/// - (`oauth` feature) `install_oauth_proxy_routes` fails;
/// - (`metrics` feature) `crate::metrics::McpMetrics::new` fails.
#[allow(
    clippy::cognitive_complexity,
    reason = "router assembly is intrinsically sequential; splitting harms readability"
)]
#[allow(
    deprecated,
    reason = "internal router assembly reads deprecated `pub` config fields by design until 1.0 makes them pub(crate)"
)]
fn build_app_router<H, F>(
    mut config: McpServerConfig,
    handler_factory: F,
) -> anyhow::Result<(axum::Router, AppRunParams)>
where
    H: ServerHandler + 'static,
    F: Fn() -> H + Send + Sync + Clone + 'static,
{
    // Root cancellation token. A child token is handed to the MCP
    // service below; a clone goes to the metrics task (if any); the
    // token itself is returned to the caller inside `AppRunParams`.
    let ct = CancellationToken::new();

    let allowed_hosts = derive_allowed_hosts(&config.bind_addr, config.public_url.as_deref());
    tracing::info!(allowed_hosts = ?allowed_hosts, "configured Streamable HTTP allowed hosts");

    let mcp_service = StreamableHttpService::new(
        move || Ok(handler_factory()),
        {
            let mut mgr = LocalSessionManager::default();
            mgr.session_config.keep_alive = Some(config.session_idle_timeout);
            mgr.into()
        },
        StreamableHttpServerConfig::default()
            .with_allowed_hosts(allowed_hosts)
            .with_sse_keep_alive(Some(config.sse_keep_alive))
            .with_cancellation_token(ct.child_token()),
    );

    // Build the MCP route, optionally wrapped with auth and RBAC middleware.
    let mut mcp_router = axum::Router::new().nest_service("/mcp", mcp_service);

    // Build auth state eagerly when auth is configured so we can wire both
    // the auth middleware *and* the optional admin router against the same
    // state. The middleware itself is installed further down in layer order.
    let auth_state: Option<Arc<AuthState>> = match config.auth {
        Some(ref auth_config) if auth_config.enabled => {
            let rate_limiter = auth_config.rate_limit.as_ref().map(build_rate_limiter);
            let pre_auth_limiter = auth_config
                .rate_limit
                .as_ref()
                .map(crate::auth::build_pre_auth_limiter);

            // `transpose` turns Option<Result<..>> into Result<Option<..>>
            // so a JWKS cache construction failure aborts router assembly.
            #[cfg(feature = "oauth")]
            let jwks_cache = auth_config
                .oauth
                .as_ref()
                .map(|c| crate::oauth::JwksCache::new(c).map(Arc::new))
                .transpose()
                .map_err(|e| std::io::Error::other(format!("JWKS HTTP client: {e}")))?;

            Some(Arc::new(AuthState {
                api_keys: ArcSwap::new(Arc::new(auth_config.api_keys.clone())),
                rate_limiter,
                pre_auth_limiter,
                #[cfg(feature = "oauth")]
                jwks_cache,
                seen_identities: crate::auth::SeenIdentitySet::new(),
                counters: crate::auth::AuthCounters::default(),
            }))
        }
        _ => None,
    };

    // Build the RBAC policy swap early so the admin router and the later
    // RBAC middleware layer share the same hot-reloadable state.
    let rbac_swap = Arc::new(ArcSwap::new(
        config
            .rbac
            .clone()
            .unwrap_or_else(|| Arc::new(RbacPolicy::disabled())),
    ));

    // Optional /admin/* diagnostic routes. Merged BEFORE the
    // body-limit/timeout/RBAC/origin/auth layers so all of them apply.
    if config.admin_enabled {
        // Same condition `check` enforces at validation time; repeated
        // here because this function receives a plain `McpServerConfig`.
        let Some(ref auth_state_ref) = auth_state else {
            return Err(anyhow::anyhow!(
                "admin_enabled=true requires auth to be configured and enabled"
            ));
        };
        let admin_state = crate::admin::AdminState {
            started_at: std::time::Instant::now(),
            name: config.name.clone(),
            version: config.version.clone(),
            auth: Some(Arc::clone(auth_state_ref)),
            rbac: Arc::clone(&rbac_swap),
        };
        let admin_cfg = crate::admin::AdminConfig {
            role: config.admin_role.clone(),
        };
        mcp_router = mcp_router.merge(crate::admin::admin_router(admin_state, &admin_cfg));
        tracing::info!(role = %config.admin_role, "/admin/* endpoints enabled");
    }

    // ----- Middleware order (CRITICAL: read carefully) ------------------
    //
    // axum/tower applies layers **bottom-up** at runtime: the LAST layer
    // added is the OUTERMOST (runs first on a request). To achieve a
    // request-time flow of:
    //
    //   body-limit -> timeout -> auth -> rbac -> handler
    //
    // we add layers in the REVERSE order:
    //
    //   1. RBAC               (innermost, runs last before handler)
    //   2. auth               (parses identity, sets extension for RBAC)
    //   3. timeout            (bounds total request time)
    //   4. body-limit         (outermost on /mcp; caps payload before
    //                          anything else reads/buffers it)
    //
    // Origin validation is installed on the OUTER router (after the
    // /mcp router is merged in), so it also protects /healthz, /readyz,
    // /version, and any OAuth proxy endpoints.
    //
    // Rationale:
    // - Body-limit must be outermost on /mcp so RBAC (which reads the
    //   JSON-RPC body) cannot be DoS'd by a 100MB payload.
    // - Auth must run before RBAC because RBAC consumes
    //   `req.extensions().get::<AuthIdentity>()` to enforce per-role
    //   policy.
    // - Origin runs before auth so we reject cross-origin requests
    //   without spending Argon2 cycles on unauthenticated callers.

    // [1] RBAC + tool rate-limit layer (innermost; closest to handler).
    // Always installed: even when RBAC is disabled, tool rate limiting may
    // be active (MCP spec: servers MUST rate limit tool invocations).
    {
        let tool_limiter: Option<Arc<ToolRateLimiter>> =
            config.tool_rate_limit.map(build_tool_rate_limiter);

        if rbac_swap.load().is_enabled() {
            tracing::info!("RBAC enforcement enabled on /mcp");
        }
        if let Some(limit) = config.tool_rate_limit {
            tracing::info!(limit, "tool rate limiting enabled (calls/min per IP)");
        }

        // The policy is re-read from the swap on every request
        // (`load_full`), so a hot reload affects subsequent requests.
        let rbac_for_mw = Arc::clone(&rbac_swap);
        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
            let p = rbac_for_mw.load_full();
            let tl = tool_limiter.clone();
            rbac_middleware(p, tl, req, next)
        }));
    }

    // [2] Auth layer (runs before RBAC so AuthIdentity is in extensions).
    if let Some(ref auth_config) = config.auth
        && auth_config.enabled
    {
        // `auth_state` was built above under the same condition, so this
        // branch is a defensive guard rather than an expected path.
        let Some(ref state) = auth_state else {
            return Err(anyhow::anyhow!("auth state missing despite enabled config"));
        };

        // Collected only for the boot-time log line below.
        let methods: Vec<&str> = [
            auth_config.mtls.is_some().then_some("mTLS"),
            (!auth_config.api_keys.is_empty()).then_some("bearer"),
            #[cfg(feature = "oauth")]
            auth_config.oauth.is_some().then_some("oauth-jwt"),
        ]
        .into_iter()
        .flatten()
        .collect();

        tracing::info!(
            methods = %methods.join(", "),
            api_keys = auth_config.api_keys.len(),
            "auth enabled on /mcp"
        );

        let state_for_mw = Arc::clone(state);
        mcp_router = mcp_router.layer(axum::middleware::from_fn(move |req, next| {
            let s = Arc::clone(&state_for_mw);
            auth_middleware(s, req, next)
        }));
    }

    // [3] Request timeout (returns 408 on expiry). Bounds total request
    // duration including auth + handler.
    mcp_router = mcp_router.layer(tower_http::timeout::TimeoutLayer::with_status_code(
        axum::http::StatusCode::REQUEST_TIMEOUT,
        config.request_timeout,
    ));

    // [4] Request body size limit (OUTERMOST on /mcp). Prevents OOM /
    // DoS from oversized payloads BEFORE any inner layer (auth, RBAC)
    // attempts to buffer or parse the body.
    mcp_router = mcp_router.layer(tower_http::limit::RequestBodyLimitLayer::new(
        config.max_request_body,
    ));

    // Compute the effective allowed-origins list for the outer
    // origin-check layer (installed on the merged router below). When
    // `allowed_origins` is empty but `public_url` is set, auto-derive
    // the origin from the public URL so MCP clients (e.g. Claude Code)
    // that send `Origin: <server-url>` are accepted without explicit
    // config.
    let mut effective_origins = config.allowed_origins.clone();
    if effective_origins.is_empty()
        && let Some(ref url) = config.public_url
    {
        // Origin = scheme + "://" + host (+ ":" + port if non-default).
        // Everything from the first `/` after the scheme separator is
        // dropped. Offsets come from `find`, so they are
        // char-boundary-aligned; `get(..)` keeps that machine-checked
        // (a violation degrades to an empty slice).
        // NOTE(review): only `/` terminates the host here, so a
        // public_url with a query or fragment but no path would keep
        // that suffix in the derived origin -- confirm that is acceptable.
        if let Some(scheme_end) = url.find("://") {
            let scheme_with_sep = url.get(..scheme_end + 3).unwrap_or_default();
            let after_scheme = url.get(scheme_end + 3..).unwrap_or_default();
            let host_end = after_scheme.find('/').unwrap_or(after_scheme.len());
            let host = after_scheme.get(..host_end).unwrap_or_default();
            let origin = format!("{scheme_with_sep}{host}");
            tracing::info!(
                %origin,
                "auto-derived allowed origin from public_url"
            );
            effective_origins.push(origin);
        }
    }
    let allowed_origins: Arc<[String]> = Arc::from(effective_origins);
    let cors_origins = Arc::clone(&allowed_origins);
    let log_request_headers = config.log_request_headers;

    // `/readyz` uses the custom readiness check when one was supplied;
    // otherwise it is served by the same handler as `/healthz`.
    let readyz_route = if let Some(check) = config.readiness_check.take() {
        axum::routing::get(move || readyz(Arc::clone(&check)))
    } else {
        axum::routing::get(healthz)
    };

    // NOTE(review): `router` is reassigned unconditionally further down,
    // so this `unused_mut` allow looks redundant -- confirm before removing.
    #[allow(unused_mut)] // mut needed when oauth feature adds PRM route
    let mut router = axum::Router::new()
        .route("/healthz", axum::routing::get(healthz))
        .route("/readyz", readyz_route)
        .route(
            "/version",
            axum::routing::get({
                // Pre-serialize the version payload once at router-build
                // time. Each request clones the `Arc` and copies the
                // bytes into the response body via `to_vec()`; only the
                // `serde_json::Value` allocation + serialization is
                // avoided per `/version` hit.
                let payload_bytes: Arc<[u8]> =
                    serialize_version_payload(&config.name, &config.version);
                move || {
                    let p = Arc::clone(&payload_bytes);
                    async move {
                        (
                            [(axum::http::header::CONTENT_TYPE, "application/json")],
                            p.to_vec(),
                        )
                    }
                }
            }),
        )
        .merge(mcp_router);

    // Merge application-specific routes (bypass MCP auth/RBAC middleware).
    // They are merged before the outer layers below are added, so those
    // layers (security headers, CORS, compression, concurrency limit,
    // metrics, origin check) still wrap them.
    if let Some(extra) = config.extra_router.take() {
        router = router.merge(extra);
    }

    // RFC 9728: Protected Resource Metadata endpoint.
    // When OAuth is configured, serve full metadata with authorization_servers.
    // Otherwise, serve a minimal document with just the resource URL and no
    // authorization_servers -- this tells MCP clients (e.g. Claude Code SDK)
    // that the server exists but does NOT require OAuth authentication,
    // preventing them from gating the connection behind a broken auth flow.
    let server_url = if let Some(ref url) = config.public_url {
        url.trim_end_matches('/').to_owned()
    } else {
        // No public URL: fall back to the bind address, with the scheme
        // chosen by whether a TLS certificate path is configured.
        let prm_scheme = if config.tls_cert_path.is_some() {
            "https"
        } else {
            "http"
        };
        format!("{prm_scheme}://{}", config.bind_addr)
    };
    let resource_url = format!("{server_url}/mcp");

    #[cfg(feature = "oauth")]
    let prm_metadata = if let Some(ref auth_config) = config.auth
        && let Some(ref oauth_config) = auth_config.oauth
    {
        crate::oauth::protected_resource_metadata(&resource_url, &server_url, oauth_config)
    } else {
        serde_json::json!({ "resource": resource_url })
    };
    #[cfg(not(feature = "oauth"))]
    let prm_metadata = serde_json::json!({ "resource": resource_url });

    // The metadata document is built once; each request serves a clone.
    router = router.route(
        "/.well-known/oauth-protected-resource",
        axum::routing::get(move || {
            let m = prm_metadata.clone();
            async move { axum::Json(m) }
        }),
    );

    // OAuth 2.1 proxy endpoints: when an OAuth proxy is configured, expose
    // /authorize, /token, /register, and authorization server metadata so
    // MCP clients can perform Authorization Code + PKCE against the upstream
    // IdP (e.g. Keycloak) transparently.
    #[cfg(feature = "oauth")]
    if let Some(ref auth_config) = config.auth
        && let Some(ref oauth_config) = auth_config.oauth
        && oauth_config.proxy.is_some()
    {
        router =
            install_oauth_proxy_routes(router, &server_url, oauth_config, auth_state.as_ref())?;
    }

    // OWASP security response headers (applied to all responses).
    // HSTS is conditional on TLS being configured.
    let is_tls = config.tls_cert_path.is_some();
    let security_headers_cfg = Arc::new(config.security_headers.clone());
    router = router.layer(axum::middleware::from_fn(move |req, next| {
        let cfg = Arc::clone(&security_headers_cfg);
        security_headers_middleware(is_tls, cfg, req, next)
    }));

    // CORS preflight layer (required for browser-based MCP clients).
    // Uses the same effective origins as the origin check middleware
    // (including auto-derived origin from public_url). Origins that do
    // not parse as a `HeaderValue` are silently skipped by the
    // `filter_map` below.
    if !cors_origins.is_empty() {
        let cors = tower_http::cors::CorsLayer::new()
            .allow_origin(
                cors_origins
                    .iter()
                    .filter_map(|o| o.parse::<axum::http::HeaderValue>().ok())
                    .collect::<Vec<_>>(),
            )
            .allow_methods([
                axum::http::Method::GET,
                axum::http::Method::POST,
                axum::http::Method::OPTIONS,
            ])
            .allow_headers([
                axum::http::header::CONTENT_TYPE,
                axum::http::header::AUTHORIZATION,
            ]);
        router = router.layer(cors);
    }

    // Optional response compression (gzip + brotli). Skips small bodies
    // to avoid overhead. Applied after CORS so preflight responses remain
    // uncompressed.
    if config.compression_enabled {
        use tower_http::compression::Predicate as _;
        let predicate = tower_http::compression::DefaultPredicate::new().and(
            tower_http::compression::predicate::SizeAbove::new(config.compression_min_size),
        );
        router = router.layer(
            tower_http::compression::CompressionLayer::new()
                .gzip(true)
                .br(true)
                .compress_when(predicate),
        );
        tracing::info!(
            min_size = config.compression_min_size,
            "response compression enabled (gzip, br)"
        );
    }

    // Optional global concurrency cap. `load_shed` converts the
    // `ConcurrencyLimit` back-pressure error into 503 instead of hanging.
    if let Some(max) = config.max_concurrent_requests {
        let overload_handler = tower::ServiceBuilder::new()
            .layer(axum::error_handling::HandleErrorLayer::new(
                |_err: tower::BoxError| async {
                    (
                        axum::http::StatusCode::SERVICE_UNAVAILABLE,
                        axum::Json(serde_json::json!({
                            "error": "overloaded",
                            "error_description": "server is at capacity, retry later"
                        })),
                    )
                },
            ))
            .layer(tower::load_shed::LoadShedLayer::new())
            .layer(tower::limit::ConcurrencyLimitLayer::new(max));
        router = router.layer(overload_handler);
        tracing::info!(max, "global concurrency limit enabled");
    }

    // JSON fallback for unmatched routes. Without this, axum returns
    // an empty-body 404 that breaks MCP clients (e.g. Claude Code SDK)
    // when they probe OAuth endpoints like /authorize or /token.
    router = router.fallback(|| async {
        (
            axum::http::StatusCode::NOT_FOUND,
            axum::Json(serde_json::json!({
                "error": "not_found",
                "error_description": "The requested endpoint does not exist"
            })),
        )
    });

    // Prometheus metrics: recording middleware + separate listener.
    // The listener runs in a spawned task that receives a clone of `ct`;
    // a failure of that task is logged, not propagated to the caller.
    #[cfg(feature = "metrics")]
    if config.metrics_enabled {
        let metrics = Arc::new(
            crate::metrics::McpMetrics::new().map_err(|e| anyhow::anyhow!("metrics init: {e}"))?,
        );
        let m = Arc::clone(&metrics);
        router = router.layer(axum::middleware::from_fn(
            move |req: Request<Body>, next: Next| {
                let m = Arc::clone(&m);
                metrics_middleware(m, req, next)
            },
        ));
        let metrics_bind = config.metrics_bind.clone();
        let metrics_shutdown = ct.clone();
        tokio::spawn(async move {
            if let Err(e) =
                crate::metrics::serve_metrics(metrics_bind, metrics, metrics_shutdown).await
            {
                tracing::error!("metrics listener failed: {e}");
            }
        });
    }

    // Origin validation layer (MCP spec: servers MUST validate the
    // Origin header to prevent DNS rebinding attacks). Installed as the
    // OUTERMOST layer on the OUTER router so it protects ALL routes
    // (`/mcp`, `/healthz`, `/readyz`, `/version`, OAuth proxy endpoints,
    // admin endpoints, extra_router, etc.) and runs BEFORE auth so we
    // reject cross-origin attackers without spending Argon2 cycles.
    //
    // Origin-less requests (e.g. server-to-server probes, curl, native
    // MCP clients) are permitted; only requests with an Origin header
    // that does not match `effective_origins` are rejected.
    router = router.layer(axum::middleware::from_fn(move |req, next| {
        let origins = Arc::clone(&allowed_origins);
        origin_check_middleware(origins, log_request_headers, req, next)
    }));

    // Values below are only packaged for the caller; nothing further is
    // added to the router.
    let scheme = if config.tls_cert_path.is_some() {
        "https"
    } else {
        "http"
    };

    // TLS is driven only when BOTH paths are present.
    let tls_paths = match (&config.tls_cert_path, &config.tls_key_path) {
        (Some(cert), Some(key)) => Some((cert.clone(), key.clone())),
        _ => None,
    };
    let tls_handshake_timeout = config.tls_handshake_timeout;
    let max_concurrent_tls_handshakes = config.max_concurrent_tls_handshakes;
    let mtls_config = config.auth.as_ref().and_then(|a| a.mtls.as_ref()).cloned();

    Ok((
        router,
        AppRunParams {
            tls_paths,
            tls_handshake_timeout,
            max_concurrent_tls_handshakes,
            mtls_config,
            shutdown_timeout: config.shutdown_timeout,
            auth_state,
            rbac_swap,
            on_reload_ready: config.on_reload_ready.take(),
            ct,
            scheme,
            name: config.name.clone(),
        },
    ))
}
1443
1444/// Run the MCP HTTP server, binding to `config.bind_addr` and serving
1445/// until an OS shutdown signal (Ctrl-C / SIGTERM) is received.
1446///
1447/// This is the standard entry point for production deployments. For
1448/// deterministic shutdown control (e.g. integration tests), see
1449/// [`serve_with_listener`].
1450///
1451/// The configuration must be validated first via
1452/// [`McpServerConfig::validate`], which returns a [`Validated`] proof
1453/// token. This typestate guarantees, at compile time, that the server
1454/// never starts with an invalid configuration.
1455///
1456/// # Errors
1457///
1458/// Returns [`McpxError::Startup`] if binding to `config.bind_addr`
1459/// fails, or if the underlying axum server returns an error.
1460pub async fn serve<H, F>(
1461    config: Validated<McpServerConfig>,
1462    handler_factory: F,
1463) -> Result<(), McpxError>
1464where
1465    H: ServerHandler + 'static,
1466    F: Fn() -> H + Send + Sync + Clone + 'static,
1467{
1468    let config = config.into_inner();
1469    #[allow(
1470        deprecated,
1471        reason = "internal serve() reads `bind_addr` to construct the listener; field becomes pub(crate) in 1.0"
1472    )]
1473    let bind_addr = config.bind_addr.clone();
1474    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1475
1476    let listener = TcpListener::bind(&bind_addr)
1477        .await
1478        .map_err(|e| io_to_startup(&format!("bind {bind_addr}"), e))?;
1479    log_listening(&params.name, params.scheme, &bind_addr);
1480
1481    run_server(
1482        router,
1483        listener,
1484        params.tls_paths,
1485        params.tls_handshake_timeout,
1486        params.max_concurrent_tls_handshakes,
1487        params.mtls_config,
1488        params.shutdown_timeout,
1489        params.auth_state,
1490        params.rbac_swap,
1491        params.on_reload_ready,
1492        params.ct,
1493    )
1494    .await
1495    .map_err(anyhow_to_startup)
1496}
1497
1498/// Run the MCP HTTP server on a pre-bound [`TcpListener`], with optional
1499/// readiness signalling and external shutdown control.
1500///
1501/// This variant is intended for **deterministic integration tests** and
1502/// for embedders that need to bind the listening socket themselves
1503/// (e.g. systemd socket activation). Compared to [`serve`]:
1504///
1505/// * The caller passes a `TcpListener` that is already bound. This
1506///   eliminates the bind race in tests that previously required
1507///   poll-the-`/healthz`-loop start-up detection.
1508/// * `ready_tx`, when `Some`, receives the socket's
1509///   [`SocketAddr`] *after* the router is built and immediately before
1510///   the server starts accepting connections. Tests can `await` the
1511///   matching `oneshot::Receiver` to know exactly when it is safe to
1512///   issue requests.
1513/// * `shutdown`, when `Some`, gives the caller a
1514///   [`CancellationToken`] that triggers the same graceful-shutdown
1515///   path as a real OS signal. This avoids cross-platform issues with
1516///   sending real `SIGTERM` from tests on Windows.
1517///
1518/// All three optional parameters degrade gracefully: if `ready_tx` is
1519/// `None`, no signal is sent; if `shutdown` is `None`, the server only
1520/// stops on an OS signal (just like [`serve`]).
1521///
1522/// # Errors
1523///
1524/// Returns [`McpxError::Startup`] if router construction fails, if reading
1525/// the listener's `local_addr()` fails, or if the underlying axum
1526/// server returns an error.
1527pub async fn serve_with_listener<H, F>(
1528    listener: TcpListener,
1529    config: Validated<McpServerConfig>,
1530    handler_factory: F,
1531    ready_tx: Option<tokio::sync::oneshot::Sender<SocketAddr>>,
1532    shutdown: Option<CancellationToken>,
1533) -> Result<(), McpxError>
1534where
1535    H: ServerHandler + 'static,
1536    F: Fn() -> H + Send + Sync + Clone + 'static,
1537{
1538    let config = config.into_inner();
1539    let local_addr = listener
1540        .local_addr()
1541        .map_err(|e| io_to_startup("listener.local_addr", e))?;
1542    let (router, params) = build_app_router(config, handler_factory).map_err(anyhow_to_startup)?;
1543
1544    log_listening(&params.name, params.scheme, &local_addr.to_string());
1545
1546    // Forward external shutdown into the server-internal cancellation
1547    // token so `run_server`'s shutdown trigger picks it up alongside
1548    // any real OS signal.
1549    if let Some(external) = shutdown {
1550        let internal = params.ct.clone();
1551        tokio::spawn(async move {
1552            external.cancelled().await;
1553            internal.cancel();
1554        });
1555    }
1556
1557    // Signal readiness *after* the router is fully built and external
1558    // shutdown is wired, but *before* run_server takes ownership of
1559    // the listener. The receiver can immediately issue requests.
1560    if let Some(tx) = ready_tx {
1561        // Receiver may have been dropped (test gave up). That's fine.
1562        let _ = tx.send(local_addr);
1563    }
1564
1565    run_server(
1566        router,
1567        listener,
1568        params.tls_paths,
1569        params.tls_handshake_timeout,
1570        params.max_concurrent_tls_handshakes,
1571        params.mtls_config,
1572        params.shutdown_timeout,
1573        params.auth_state,
1574        params.rbac_swap,
1575        params.on_reload_ready,
1576        params.ct,
1577    )
1578    .await
1579    .map_err(anyhow_to_startup)
1580}
1581
1582/// Emit the standard "listening on …" log lines used by both
1583/// [`serve`] and [`serve_with_listener`].
1584#[allow(
1585    clippy::cognitive_complexity,
1586    reason = "tracing::info! macro expansions inflate the score; logic is trivial"
1587)]
1588fn log_listening(name: &str, scheme: &str, addr: &str) {
1589    tracing::info!("{name} listening on {addr}");
1590    tracing::info!("  MCP endpoint: {scheme}://{addr}/mcp");
1591    tracing::info!("  Health check: {scheme}://{addr}/healthz");
1592    tracing::info!("  Readiness:   {scheme}://{addr}/readyz");
1593}
1594
1595/// Drive the chosen axum server variant (TLS or plain) with a graceful
1596/// shutdown window. Consumes the router and listener.
1597///
1598/// # Shutdown semantics
1599///
1600/// A single shutdown trigger (the FIRST of: OS signal via
1601/// `shutdown_signal()`, or external cancellation of `ct`) starts BOTH:
1602///
1603/// 1. axum's `.with_graceful_shutdown(...)` future, which stops
1604///    accepting new connections and waits for in-flight requests to
1605///    drain;
1606/// 2. a `tokio::time::sleep(shutdown_timeout)` race that forces exit if
1607///    drainage exceeds `shutdown_timeout`.
1608///
1609/// Previously this function awaited `shutdown_signal()` independently
1610/// in BOTH branches of a `tokio::select!`. Because `shutdown_signal`
1611/// resolves once per future and consumes one signal, the force-exit
1612/// timer was tied to a SECOND signal (a second SIGTERM the operator
1613/// would never send). Under a single SIGTERM the graceful drain could
1614/// hang indefinitely. The current implementation derives both branches
1615/// from a single shared trigger so the timeout race is anchored to the
1616/// FIRST (and only) signal.
1617#[allow(
1618    clippy::too_many_arguments,
1619    clippy::cognitive_complexity,
1620    reason = "server start-up threads TLS, reload state, and graceful shutdown through one flow"
1621)]
1622async fn run_server(
1623    router: axum::Router,
1624    listener: TcpListener,
1625    tls_paths: Option<(PathBuf, PathBuf)>,
1626    tls_handshake_timeout: Duration,
1627    max_concurrent_tls_handshakes: usize,
1628    mtls_config: Option<MtlsConfig>,
1629    shutdown_timeout: Duration,
1630    auth_state: Option<Arc<AuthState>>,
1631    rbac_swap: Arc<ArcSwap<RbacPolicy>>,
1632    mut on_reload_ready: Option<Box<dyn FnOnce(ReloadHandle) + Send>>,
1633    ct: CancellationToken,
1634) -> anyhow::Result<()> {
1635    // `shutdown_trigger` fires when the FIRST source resolves: either
1636    // an OS signal (Ctrl-C / SIGTERM) or external cancellation of `ct`
1637    // (which the test harness uses for deterministic shutdown).
1638    let shutdown_trigger = CancellationToken::new();
1639    {
1640        let trigger = shutdown_trigger.clone();
1641        let parent = ct.clone();
1642        tokio::spawn(async move {
1643            tokio::select! {
1644                () = shutdown_signal() => {}
1645                () = parent.cancelled() => {}
1646            }
1647            trigger.cancel();
1648        });
1649    }
1650
1651    let graceful = {
1652        let trigger = shutdown_trigger.clone();
1653        let ct = ct.clone();
1654        async move {
1655            trigger.cancelled().await;
1656            tracing::info!("shutting down (grace period: {shutdown_timeout:?})");
1657            ct.cancel();
1658        }
1659    };
1660
1661    let force_exit_timer = {
1662        let trigger = shutdown_trigger.clone();
1663        async move {
1664            trigger.cancelled().await;
1665            tokio::time::sleep(shutdown_timeout).await;
1666        }
1667    };
1668
1669    if let Some((cert_path, key_path)) = tls_paths {
1670        let crl_set = if let Some(mtls) = mtls_config.as_ref()
1671            && mtls.crl_enabled
1672        {
1673            let (ca_certs, roots) = load_client_auth_roots(&mtls.ca_cert_path)?;
1674            let (crl_set, discover_rx) =
1675                mtls_revocation::bootstrap_fetch(roots, &ca_certs, mtls.clone())
1676                    .await
1677                    .map_err(|error| anyhow::anyhow!(error.to_string()))?;
1678            tokio::spawn(mtls_revocation::run_crl_refresher(
1679                Arc::clone(&crl_set),
1680                discover_rx,
1681                ct.clone(),
1682            ));
1683            Some(crl_set)
1684        } else {
1685            None
1686        };
1687
1688        if let Some(cb) = on_reload_ready.take() {
1689            cb(ReloadHandle {
1690                auth: auth_state.clone(),
1691                rbac: Some(Arc::clone(&rbac_swap)),
1692                crl_set: crl_set.clone(),
1693            });
1694        }
1695
1696        let tls_listener = TlsListener::new(
1697            listener,
1698            &cert_path,
1699            &key_path,
1700            mtls_config.as_ref(),
1701            crl_set,
1702            tls_handshake_timeout,
1703            max_concurrent_tls_handshakes,
1704        )?;
1705        let make_svc = router.into_make_service_with_connect_info::<TlsConnInfo>();
1706        tokio::select! {
1707            result = axum::serve(tls_listener, make_svc)
1708                .with_graceful_shutdown(graceful) => { result?; }
1709            () = force_exit_timer => {
1710                tracing::warn!("shutdown timeout exceeded, forcing exit");
1711            }
1712        }
1713    } else {
1714        if let Some(cb) = on_reload_ready.take() {
1715            cb(ReloadHandle {
1716                auth: auth_state,
1717                rbac: Some(rbac_swap),
1718                crl_set: None,
1719            });
1720        }
1721
1722        let make_svc = router.into_make_service_with_connect_info::<SocketAddr>();
1723        tokio::select! {
1724            result = axum::serve(listener, make_svc)
1725                .with_graceful_shutdown(graceful) => { result?; }
1726            () = force_exit_timer => {
1727                tracing::warn!("shutdown timeout exceeded, forcing exit");
1728            }
1729        }
1730    }
1731
1732    Ok(())
1733}
1734
/// Mount the OAuth 2.1 proxy endpoints on `router`: the authorization
/// server metadata document, `/authorize`, `/token` and `/register`,
/// plus the optional admin sub-router built by
/// [`build_oauth_admin_router`].
///
/// When `oauth_config.proxy` is `None` the router is returned unchanged.
///
/// # Errors
///
/// Returns an error ([`McpxError::Startup`]) if the shared
/// [`crate::oauth::OauthHttpClient`] cannot be initialized, and
/// propagates any error from building the admin sub-router.
#[cfg(feature = "oauth")]
fn install_oauth_proxy_routes(
    router: axum::Router,
    server_url: &str,
    oauth_config: &crate::oauth::OAuthConfig,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let Some(proxy) = oauth_config.proxy.as_ref() else {
        return Ok(router);
    };

    // One HTTP client shared by every proxy endpoint; each handler gets
    // its own clone of it.
    let http = crate::oauth::OauthHttpClient::with_config(oauth_config)?;
    let metadata = crate::oauth::authorization_server_metadata(server_url, oauth_config);

    // Per-route owned copies captured by the handler closures below.
    let authorize_proxy = proxy.clone();
    let token_proxy = proxy.clone();
    let token_client = http.clone();
    let register_proxy = proxy.clone();

    let router = router
        .route(
            "/.well-known/oauth-authorization-server",
            axum::routing::get(move || {
                let document = metadata.clone();
                async move { axum::Json(document) }
            }),
        )
        .route(
            "/authorize",
            axum::routing::get(
                move |axum::extract::RawQuery(query): axum::extract::RawQuery| {
                    let cfg = authorize_proxy.clone();
                    async move { crate::oauth::handle_authorize(&cfg, &query.unwrap_or_default()) }
                },
            ),
        )
        .route(
            "/token",
            axum::routing::post(move |body: String| {
                let cfg = token_proxy.clone();
                let client = token_client.clone();
                async move { crate::oauth::handle_token(&client, &cfg, &body).await }
            })
            .layer(axum::middleware::from_fn(
                oauth_token_cache_headers_middleware,
            )),
        )
        .route(
            "/register",
            axum::routing::post(move |axum::Json(body): axum::Json<serde_json::Value>| {
                let cfg = register_proxy;
                async move { axum::Json(crate::oauth::handle_register(&cfg, &body)) }
            })
            .layer(axum::middleware::from_fn(
                oauth_token_cache_headers_middleware,
            )),
        );

    let admin_exposed = proxy.expose_admin_endpoints;
    let has_introspection = proxy.introspection_url.is_some();
    let has_revocation = proxy.revocation_url.is_some();

    if admin_exposed
        && !proxy.require_auth_on_admin_endpoints
        && proxy.allow_unauthenticated_admin_endpoints
    {
        // The operator explicitly opted out of auth on the admin
        // endpoints; say so loudly at startup so the choice shows up in
        // the logs.
        tracing::warn!(
            "OAuth introspect/revoke endpoints are unauthenticated by explicit \
             allow_unauthenticated_admin_endpoints opt-out; ensure an \
             authenticated reverse proxy fronts these routes"
        );
    }

    // Admin routes are mounted only when exposed AND at least one
    // upstream URL is configured; otherwise an empty router is merged.
    let admin_router = if admin_exposed && (has_introspection || has_revocation) {
        build_oauth_admin_router(proxy, http, auth_state)?
    } else {
        axum::Router::new()
    };
    let router = router.merge(admin_router);

    tracing::info!(
        introspect = admin_exposed && has_introspection,
        revoke = admin_exposed && has_revocation,
        "OAuth 2.1 proxy endpoints enabled (/authorize, /token, /register)"
    );
    Ok(router)
}
1835
/// Assemble the optional admin sub-router: `/introspect` when
/// `proxy.introspection_url` is set and `/revoke` when
/// `proxy.revocation_url` is set.
///
/// Every admin route is wrapped in
/// [`oauth_token_cache_headers_middleware`]. When
/// `proxy.require_auth_on_admin_endpoints` is set, the auth middleware is
/// layered on top of that.
///
/// # Errors
///
/// Returns [`McpxError::Startup`] when auth is required on the admin
/// endpoints but `auth_state` is `None`.
#[cfg(feature = "oauth")]
fn build_oauth_admin_router(
    proxy: &crate::oauth::OAuthProxyConfig,
    http: crate::oauth::OauthHttpClient,
    auth_state: Option<&Arc<AuthState>>,
) -> Result<axum::Router, McpxError> {
    let mut routes = axum::Router::new();

    if proxy.introspection_url.is_some() {
        let cfg = proxy.clone();
        let client = http.clone();
        routes = routes.route(
            "/introspect",
            axum::routing::post(move |body: String| {
                let cfg = cfg.clone();
                let client = client.clone();
                async move { crate::oauth::handle_introspect(&client, &cfg, &body).await }
            }),
        );
    }

    if proxy.revocation_url.is_some() {
        let cfg = proxy.clone();
        // Last user of the shared client: move it in rather than clone.
        let client = http;
        routes = routes.route(
            "/revoke",
            axum::routing::post(move |body: String| {
                let cfg = cfg.clone();
                let client = client.clone();
                async move { crate::oauth::handle_revoke(&client, &cfg, &body).await }
            }),
        );
    }

    let routes = routes.layer(axum::middleware::from_fn(
        oauth_token_cache_headers_middleware,
    ));

    match (proxy.require_auth_on_admin_endpoints, auth_state) {
        (false, _) => Ok(routes),
        (true, None) => Err(McpxError::Startup(
            "oauth proxy admin endpoints require auth state".into(),
        )),
        (true, Some(state)) => {
            let shared_state = Arc::clone(state);
            Ok(routes.layer(axum::middleware::from_fn(move |req, next| {
                auth_middleware(Arc::clone(&shared_state), req, next)
            })))
        }
    }
}
1894
1895/// Build the host allow-list for rmcp's DNS rebinding protection.
1896///
1897/// Includes loopback hosts by default, then augments with host/authority
1898/// derived from `public_url` and the server bind address.
1899fn derive_allowed_hosts(bind_addr: &str, public_url: Option<&str>) -> Vec<String> {
1900    let mut hosts = vec![
1901        "localhost".to_owned(),
1902        "127.0.0.1".to_owned(),
1903        "::1".to_owned(),
1904    ];
1905
1906    if let Some(url) = public_url
1907        && let Ok(uri) = url.parse::<axum::http::Uri>()
1908        && let Some(authority) = uri.authority()
1909    {
1910        let host = authority.host().to_owned();
1911        if !hosts.iter().any(|h| h == &host) {
1912            hosts.push(host);
1913        }
1914
1915        let authority = authority.as_str().to_owned();
1916        if !hosts.iter().any(|h| h == &authority) {
1917            hosts.push(authority);
1918        }
1919    }
1920
1921    if let Ok(uri) = format!("http://{bind_addr}").parse::<axum::http::Uri>()
1922        && let Some(authority) = uri.authority()
1923    {
1924        let host = authority.host().to_owned();
1925        if !hosts.iter().any(|h| h == &host) {
1926            hosts.push(host);
1927        }
1928
1929        let authority = authority.as_str().to_owned();
1930        if !hosts.iter().any(|h| h == &authority) {
1931            hosts.push(authority);
1932        }
1933    }
1934
1935    hosts
1936}
1937
1938// - TLS support -
1939
1940/// Implement axum's `Connected` trait for `TlsConnInfo` so that
1941/// `ConnectInfo<TlsConnInfo>` is available in middleware when serving
1942/// over our custom `TlsListener`.
1943///
1944/// The identity is read directly from the wrapping
1945/// [`AuthenticatedTlsStream`], which guarantees one-to-one correspondence
1946/// between the TLS connection and its mTLS identity. This eliminates the
1947/// previous shared-map approach which was vulnerable to ephemeral-port
1948/// reuse races (an unauthenticated reconnection from the same `(IP, port)`
1949/// pair could alias a stale entry).
1950impl axum::extract::connect_info::Connected<axum::serve::IncomingStream<'_, TlsListener>>
1951    for TlsConnInfo
1952{
1953    fn connect_info(target: axum::serve::IncomingStream<'_, TlsListener>) -> Self {
1954        let addr = *target.remote_addr();
1955        let identity = target.io().identity().cloned();
1956        TlsConnInfo::new(addr, identity)
1957    }
1958}
1959
/// Default deadline (10 seconds) for a single TLS handshake on the accept
/// path. Bounding each handshake keeps idle or slow-loris connections
/// from pinning handshake worker tasks (and their semaphore permits)
/// indefinitely.
///
/// Configurable since 1.9.0 via
/// [`McpServerConfig::with_tls_handshake_timeout`].
const DEFAULT_TLS_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

/// Default cap (256) on TLS handshakes that may be in flight at once.
/// Once the cap is reached the acceptor task stops pulling connections
/// from the kernel backlog (backpressure) rather than accepting them and
/// dropping them in user space.
///
/// Configurable since 1.9.0 via
/// [`McpServerConfig::with_max_concurrent_tls_handshakes`].
const DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES: usize = 256;

/// Size (32) of the bounded queue that carries completed handshakes from
/// the acceptor task to `axum::serve`'s `accept()` loop. Handshake
/// workers wait in `send` while the queue is full, so a slow accept loop
/// back-pressures handshakes instead of letting finished connections
/// pile up without bound.
const TLS_ACCEPT_CHANNEL_CAPACITY: usize = 32;
1982
/// A TLS-wrapping listener that implements axum's `Listener` trait.
///
/// TCP accepts and TLS handshakes run on a dedicated background task: each
/// accepted connection's handshake is spawned onto its own worker task,
/// bounded by a configurable concurrent-handshake cap (default
/// [`DEFAULT_MAX_CONCURRENT_TLS_HANDSHAKES`]) and a per-handshake timeout
/// (default [`DEFAULT_TLS_HANDSHAKE_TIMEOUT`]). A slow or idle client
/// therefore cannot stall other connections behind a serialized inline
/// handshake.
///
/// When mTLS is configured, client certificates are verified against the
/// configured CA and the client identity is extracted at handshake time.
/// The extracted identity is bound to the connection itself via the
/// returned [`AuthenticatedTlsStream`], so it is impossible for an
/// unrelated connection to observe it.
///
/// Instances are built by `TlsListener::new`, which loads the server
/// certificate and key, assembles the rustls server configuration and
/// spawns the acceptor task ([`run_tls_acceptor`]).
struct TlsListener {
    /// Bound address, captured eagerly before the `TcpListener` moves into
    /// the acceptor task.
    local_addr: SocketAddr,
    /// Completed handshakes produced by the acceptor task's workers, each
    /// paired with the peer's socket address. This is the receiving end
    /// of a bounded channel of capacity [`TLS_ACCEPT_CHANNEL_CAPACITY`].
    rx: mpsc::Receiver<(AuthenticatedTlsStream, SocketAddr)>,
    /// Background task driving TCP accepts and concurrent TLS handshakes.
    /// Aborted on drop so the listener releases its port deterministically.
    acceptor_task: tokio::task::JoinHandle<()>,
}
2008
2009impl TlsListener {
2010    fn new(
2011        inner: TcpListener,
2012        cert_path: &Path,
2013        key_path: &Path,
2014        mtls_config: Option<&MtlsConfig>,
2015        crl_set: Option<Arc<CrlSet>>,
2016        handshake_timeout: Duration,
2017        max_concurrent_handshakes: usize,
2018    ) -> anyhow::Result<Self> {
2019        // Install the ring crypto provider (ok to call multiple times).
2020        rustls::crypto::ring::default_provider()
2021            .install_default()
2022            .ok();
2023
2024        let certs = load_certs(cert_path)?;
2025        let key = load_key(key_path)?;
2026
2027        let mtls_default_role;
2028
2029        let tls_config = if let Some(mtls) = mtls_config {
2030            mtls_default_role = mtls.default_role.clone();
2031            let verifier: Arc<dyn rustls::server::danger::ClientCertVerifier> = if mtls.crl_enabled
2032            {
2033                let Some(crl_set) = crl_set else {
2034                    return Err(anyhow::anyhow!(
2035                        "mTLS CRL verifier requested but CRL state was not initialized"
2036                    ));
2037                };
2038                Arc::new(DynamicClientCertVerifier::new(crl_set))
2039            } else {
2040                let (_, root_store) = load_client_auth_roots(&mtls.ca_cert_path)?;
2041                if mtls.required {
2042                    rustls::server::WebPkiClientVerifier::builder(root_store)
2043                        .build()
2044                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
2045                } else {
2046                    rustls::server::WebPkiClientVerifier::builder(root_store)
2047                        .allow_unauthenticated()
2048                        .build()
2049                        .map_err(|e| anyhow::anyhow!("mTLS verifier error: {e}"))?
2050                }
2051            };
2052
2053            tracing::info!(
2054                ca = %mtls.ca_cert_path.display(),
2055                required = mtls.required,
2056                crl_enabled = mtls.crl_enabled,
2057                "mTLS client auth configured"
2058            );
2059
2060            rustls::ServerConfig::builder_with_protocol_versions(&[
2061                &rustls::version::TLS12,
2062                &rustls::version::TLS13,
2063            ])
2064            .with_client_cert_verifier(verifier)
2065            .with_single_cert(certs, key)?
2066        } else {
2067            mtls_default_role = "viewer".to_owned();
2068            rustls::ServerConfig::builder_with_protocol_versions(&[
2069                &rustls::version::TLS12,
2070                &rustls::version::TLS13,
2071            ])
2072            .with_no_client_auth()
2073            .with_single_cert(certs, key)?
2074        };
2075
2076        let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
2077        tracing::info!(
2078            "TLS enabled (cert: {}, key: {})",
2079            cert_path.display(),
2080            key_path.display()
2081        );
2082        let local_addr = inner.local_addr()?;
2083        let (tx, rx) = mpsc::channel(TLS_ACCEPT_CHANNEL_CAPACITY);
2084        let acceptor_task = tokio::spawn(run_tls_acceptor(
2085            inner,
2086            acceptor,
2087            mtls_default_role,
2088            tx,
2089            handshake_timeout,
2090            max_concurrent_handshakes,
2091        ));
2092        Ok(Self {
2093            local_addr,
2094            rx,
2095            acceptor_task,
2096        })
2097    }
2098
2099    /// Extract the mTLS client cert identity from a completed TLS handshake.
2100    /// Returns `None` if no client certificate was presented or if the
2101    /// certificate could not be parsed into an [`AuthIdentity`].
2102    fn extract_handshake_identity(
2103        tls_stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
2104        default_role: &str,
2105        addr: SocketAddr,
2106    ) -> Option<AuthIdentity> {
2107        let (_, server_conn) = tls_stream.get_ref();
2108        let cert_der = server_conn.peer_certificates()?.first()?;
2109        let id = extract_mtls_identity(cert_der.as_ref(), default_role)?;
2110        tracing::debug!(name = %id.name, peer = %addr, "mTLS client cert accepted");
2111        Some(id)
2112    }
2113}
2114
2115/// Drive TCP accepts and concurrent TLS handshakes for [`TlsListener`].
2116///
2117/// Each accepted connection's handshake runs on its own worker task under
2118/// a permit from a `max_concurrent_handshakes`-sized semaphore and a
2119/// `handshake_timeout` deadline. Completed handshakes are pushed to `tx`;
2120/// failures and timeouts are logged at DEBUG and the connection dropped.
2121/// The loop exits when the owning [`TlsListener`] is dropped.
2122async fn run_tls_acceptor(
2123    listener: TcpListener,
2124    acceptor: tokio_rustls::TlsAcceptor,
2125    default_role: String,
2126    tx: mpsc::Sender<(AuthenticatedTlsStream, SocketAddr)>,
2127    handshake_timeout: Duration,
2128    max_concurrent_handshakes: usize,
2129) {
2130    let inflight = Arc::new(Semaphore::new(max_concurrent_handshakes));
2131    loop {
2132        // Acquire the permit BEFORE accepting: at saturation, pending
2133        // connections wait in the kernel backlog instead of being accepted
2134        // and then buffered or dropped in user space.
2135        let Ok(permit) = Arc::clone(&inflight).acquire_owned().await else {
2136            // The semaphore is never closed; defensive exit.
2137            return;
2138        };
2139        let (stream, addr) = match listener.accept().await {
2140            Ok(pair) => pair,
2141            Err(e) => {
2142                tracing::debug!("TCP accept error: {e}");
2143                continue;
2144            }
2145        };
2146        if tx.is_closed() {
2147            // The listener was dropped (shutdown): stop accepting.
2148            return;
2149        }
2150        let acceptor = acceptor.clone();
2151        let default_role = default_role.clone();
2152        let tx = tx.clone();
2153        tokio::spawn(async move {
2154            let _permit = permit;
2155            match tokio::time::timeout(handshake_timeout, acceptor.accept(stream)).await {
2156                Ok(Ok(tls_stream)) => {
2157                    let identity =
2158                        TlsListener::extract_handshake_identity(&tls_stream, &default_role, addr);
2159                    let wrapped = AuthenticatedTlsStream {
2160                        inner: tls_stream,
2161                        identity,
2162                    };
2163                    // The receiver only disappears during shutdown; discard
2164                    // the completed connection quietly rather than logging.
2165                    let _ = tx.send((wrapped, addr)).await;
2166                }
2167                Ok(Err(e)) => {
2168                    tracing::debug!("TLS handshake failed from {addr}: {e}");
2169                }
2170                Err(_elapsed) => {
2171                    tracing::debug!(
2172                        "TLS handshake timed out from {addr} after {handshake_timeout:?}"
2173                    );
2174                }
2175            }
2176        });
2177    }
2178}
2179
/// A TLS stream paired with the mTLS identity extracted at handshake time.
///
/// Wraps [`tokio_rustls::server::TlsStream`] so the verified client
/// identity travels with the connection itself. This replaces the previous
/// shared `MtlsIdentities` map, eliminating the
/// `(SocketAddr) -> AuthIdentity` aliasing risk caused by ephemeral-port
/// reuse and removing the need for an LRU eviction policy.
///
/// The wrapper is `Unpin` (its inner stream is `Unpin` because
/// [`tokio::net::TcpStream`] is `Unpin`), so `AsyncRead`/`AsyncWrite`
/// delegation uses safe pin projection via `Pin::new(&mut self.inner)`.
pub(crate) struct AuthenticatedTlsStream {
    /// The established server-side TLS stream; all reads and writes are
    /// delegated to it.
    inner: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
    /// Identity extracted from the peer's client certificate at handshake
    /// time, or `None` when no identity was obtained.
    identity: Option<AuthIdentity>,
}
2195
impl AuthenticatedTlsStream {
    /// Returns the verified mTLS client identity, if any.
    ///
    /// This borrows the `identity` value stored in the wrapper; `None`
    /// means no identity was attached to this connection.
    #[must_use]
    pub(crate) const fn identity(&self) -> Option<&AuthIdentity> {
        self.identity.as_ref()
    }
}
2203
2204impl std::fmt::Debug for AuthenticatedTlsStream {
2205    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2206        f.debug_struct("AuthenticatedTlsStream")
2207            .field("identity", &self.identity.as_ref().map(|id| &id.name))
2208            .finish_non_exhaustive()
2209    }
2210}
2211
2212impl tokio::io::AsyncRead for AuthenticatedTlsStream {
2213    fn poll_read(
2214        mut self: Pin<&mut Self>,
2215        cx: &mut std::task::Context<'_>,
2216        buf: &mut tokio::io::ReadBuf<'_>,
2217    ) -> std::task::Poll<std::io::Result<()>> {
2218        Pin::new(&mut self.inner).poll_read(cx, buf)
2219    }
2220}
2221
2222impl tokio::io::AsyncWrite for AuthenticatedTlsStream {
2223    fn poll_write(
2224        mut self: Pin<&mut Self>,
2225        cx: &mut std::task::Context<'_>,
2226        buf: &[u8],
2227    ) -> std::task::Poll<std::io::Result<usize>> {
2228        Pin::new(&mut self.inner).poll_write(cx, buf)
2229    }
2230
2231    fn poll_flush(
2232        mut self: Pin<&mut Self>,
2233        cx: &mut std::task::Context<'_>,
2234    ) -> std::task::Poll<std::io::Result<()>> {
2235        Pin::new(&mut self.inner).poll_flush(cx)
2236    }
2237
2238    fn poll_shutdown(
2239        mut self: Pin<&mut Self>,
2240        cx: &mut std::task::Context<'_>,
2241    ) -> std::task::Poll<std::io::Result<()>> {
2242        Pin::new(&mut self.inner).poll_shutdown(cx)
2243    }
2244
2245    fn poll_write_vectored(
2246        mut self: Pin<&mut Self>,
2247        cx: &mut std::task::Context<'_>,
2248        bufs: &[std::io::IoSlice<'_>],
2249    ) -> std::task::Poll<std::io::Result<usize>> {
2250        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
2251    }
2252
2253    fn is_write_vectored(&self) -> bool {
2254        self.inner.is_write_vectored()
2255    }
2256}
2257
2258impl axum::serve::Listener for TlsListener {
2259    type Io = AuthenticatedTlsStream;
2260    type Addr = SocketAddr;
2261
2262    /// Yield the next fully-handshaken TLS connection.
2263    ///
2264    /// Cancel-safe: this is a plain `mpsc::Receiver::recv`, so cancelling
2265    /// the future (axum selects it against graceful shutdown) never loses
2266    /// a connection.
2267    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
2268        if let Some(pair) = self.rx.recv().await {
2269            return pair;
2270        }
2271        // The channel only closes if the acceptor task terminated, which
2272        // means the TcpListener is gone and the OS already refuses new
2273        // connections. `Listener::accept` is infallible and panicking is
2274        // forbidden, so park forever: existing connections keep being
2275        // served and graceful shutdown still completes.
2276        tracing::error!("TLS acceptor task terminated; no further connections will be accepted");
2277        std::future::pending().await
2278    }
2279
2280    fn local_addr(&self) -> std::io::Result<Self::Addr> {
2281        Ok(self.local_addr)
2282    }
2283}
2284
/// Dropping the listener aborts its background acceptor task.
impl Drop for TlsListener {
    fn drop(&mut self) {
        // Stop accepting immediately and release the bound port. In-flight
        // handshake workers notice the closed channel and exit quietly.
        // NOTE(review): `acceptor_task` is presumably a tokio `JoinHandle`,
        // in which case `abort()` requests cancellation without waiting.
        self.acceptor_task.abort();
    }
}
2292
2293fn load_certs(path: &Path) -> anyhow::Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
2294    use rustls::pki_types::pem::PemObject;
2295    let certs: Vec<_> = rustls::pki_types::CertificateDer::pem_file_iter(path)
2296        .map_err(|e| anyhow::anyhow!("failed to read certs from {}: {e}", path.display()))?
2297        .collect::<Result<_, _>>()
2298        .map_err(|e| anyhow::anyhow!("invalid cert in {}: {e}", path.display()))?;
2299    anyhow::ensure!(
2300        !certs.is_empty(),
2301        "no certificates found in {}",
2302        path.display()
2303    );
2304    Ok(certs)
2305}
2306
2307fn load_client_auth_roots(
2308    path: &Path,
2309) -> anyhow::Result<(
2310    Vec<rustls::pki_types::CertificateDer<'static>>,
2311    Arc<RootCertStore>,
2312)> {
2313    let ca_certs = load_certs(path)?;
2314    let mut root_store = RootCertStore::empty();
2315    for cert in &ca_certs {
2316        root_store
2317            .add(cert.clone())
2318            .map_err(|error| anyhow::anyhow!("invalid CA cert: {error}"))?;
2319    }
2320
2321    Ok((ca_certs, Arc::new(root_store)))
2322}
2323
2324fn load_key(path: &Path) -> anyhow::Result<rustls::pki_types::PrivateKeyDer<'static>> {
2325    use rustls::pki_types::pem::PemObject;
2326    rustls::pki_types::PrivateKeyDer::from_pem_file(path)
2327        .map_err(|e| anyhow::anyhow!("failed to read key from {}: {e}", path.display()))
2328}
2329
2330#[allow(
2331    clippy::unused_async,
2332    reason = "axum route handler signature requires `async fn` even when the body is synchronous"
2333)]
2334async fn healthz() -> impl IntoResponse {
2335    axum::Json(serde_json::json!({
2336        "status": "ok",
2337    }))
2338}
2339
2340/// Build the `/version` JSON payload for a given server name and version.
2341///
2342/// Build metadata (`build_git_sha`, `build_timestamp`, `rust_version`) is
2343/// read at compile time from the `RMCP_SERVER_KIT_BUILD_SHA`,
2344/// `RMCP_SERVER_KIT_BUILD_TIME`, and `RMCP_SERVER_KIT_RUSTC_VERSION` env
2345/// vars. Unset values resolve to `"unknown"`.
2346fn version_payload(name: &str, version: &str) -> serde_json::Value {
2347    serde_json::json!({
2348        "name": name,
2349        "version": version,
2350        "build_git_sha": option_env!("RMCP_SERVER_KIT_BUILD_SHA").unwrap_or("unknown"),
2351        "build_timestamp": option_env!("RMCP_SERVER_KIT_BUILD_TIME").unwrap_or("unknown"),
2352        "rust_version": option_env!("RMCP_SERVER_KIT_RUSTC_VERSION").unwrap_or("unknown"),
2353        "mcpx_version": env!("CARGO_PKG_VERSION"),
2354    })
2355}
2356
2357/// Pre-serialize the `/version` payload to immutable bytes.
2358///
2359/// This is called once at router-build time so per-request handling can
2360/// reuse a cheap `Arc<[u8]>` clone instead of re-serializing a
2361/// [`serde_json::Value`] on every hit.
2362///
2363/// Serialization of a flat `serde_json::Value` of static-string fields
2364/// cannot fail in practice; the fallback to `b"{}"` exists only to
2365/// satisfy the crate-wide `unwrap_used` / `expect_used` lint policy.
2366fn serialize_version_payload(name: &str, version: &str) -> Arc<[u8]> {
2367    let value = version_payload(name, version);
2368    serde_json::to_vec(&value).map_or_else(|_| Arc::from(&b"{}"[..]), Arc::from)
2369}
2370
2371async fn readyz(check: ReadinessCheck) -> impl IntoResponse {
2372    let status = check().await;
2373    let ready = status
2374        .get("ready")
2375        .and_then(serde_json::Value::as_bool)
2376        .unwrap_or(false);
2377    let code = if ready {
2378        axum::http::StatusCode::OK
2379    } else {
2380        axum::http::StatusCode::SERVICE_UNAVAILABLE
2381    };
2382    (code, axum::Json(status))
2383}
2384
2385/// Wait for SIGINT (ctrl-c) or SIGTERM (container stop).
2386///
2387/// On non-Unix platforms, only SIGINT is handled.
2388async fn shutdown_signal() {
2389    let ctrl_c = tokio::signal::ctrl_c();
2390
2391    #[cfg(unix)]
2392    {
2393        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
2394            Ok(mut term) => {
2395                tokio::select! {
2396                    _ = ctrl_c => {}
2397                    _ = term.recv() => {}
2398                }
2399            }
2400            Err(e) => {
2401                tracing::warn!(error = %e, "failed to register SIGTERM handler, using SIGINT only");
2402                ctrl_c.await.ok();
2403            }
2404        }
2405    }
2406
2407    #[cfg(not(unix))]
2408    {
2409        ctrl_c.await.ok();
2410    }
2411}
2412
2413// -- Origin validation (MCP 2025-11-25 spec, section 2.0.1) --
2414
/// Record HTTP request metrics (method, path, status, duration).
///
/// After the inner service responds, increments `http_requests_total`
/// labelled by method, path, and status code, and observes the elapsed
/// time in seconds on `http_request_duration_seconds` labelled by method
/// and path. The response is returned unmodified.
///
/// NOTE(review): the raw request path is used as a label value, so the
/// number of label combinations grows with the number of distinct paths
/// requested -- confirm this is acceptable for the deployment.
#[cfg(feature = "metrics")]
async fn metrics_middleware(
    metrics: Arc<crate::metrics::McpMetrics>,
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    // Capture the labels before `req` is moved into the inner service.
    let method_label = req.method().to_string();
    let path_label = req.uri().path().to_owned();
    let started_at = std::time::Instant::now();

    let response = next.run(req).await;

    let status_label = response.status().as_u16().to_string();
    let elapsed_secs = started_at.elapsed().as_secs_f64();

    let request_counter = metrics
        .http_requests_total
        .with_label_values(&[&method_label, &path_label, &status_label]);
    request_counter.inc();

    let duration_histogram = metrics
        .http_request_duration_seconds
        .with_label_values(&[&method_label, &path_label]);
    duration_histogram.observe(elapsed_secs);

    response
}
2444
2445/// OWASP security header hardening applied to every response.
2446///
2447/// Sets: `X-Content-Type-Options`, `X-Frame-Options`, `Cache-Control`,
2448/// `Referrer-Policy`, `Cross-Origin-Opener-Policy`, `Cross-Origin-Resource-Policy`,
2449/// `Cross-Origin-Embedder-Policy`, `Permissions-Policy`,
2450/// `X-Permitted-Cross-Domain-Policies`, `Content-Security-Policy`,
2451/// `X-DNS-Prefetch-Control`, and (when TLS is active) `Strict-Transport-Security`.
2452///
2453/// Each header's value can be customised via [`SecurityHeadersConfig`]
2454/// on [`McpServerConfig`]. See that type for the three-state semantic
2455/// (`None` = default, `Some("")` = omit, `Some(v)` = override).
2456async fn security_headers_middleware(
2457    is_tls: bool,
2458    cfg: Arc<SecurityHeadersConfig>,
2459    req: Request<Body>,
2460    next: Next,
2461) -> axum::response::Response {
2462    use axum::http::{HeaderName, header};
2463
2464    let mut resp = next.run(req).await;
2465    let headers = resp.headers_mut();
2466
2467    // Strip server identity headers to reduce information leakage.
2468    headers.remove(header::SERVER);
2469    headers.remove(HeaderName::from_static("x-powered-by"));
2470
2471    apply_security_header(
2472        headers,
2473        header::X_CONTENT_TYPE_OPTIONS,
2474        cfg.x_content_type_options.as_deref(),
2475        "nosniff",
2476    );
2477    apply_security_header(
2478        headers,
2479        header::X_FRAME_OPTIONS,
2480        cfg.x_frame_options.as_deref(),
2481        "deny",
2482    );
2483    apply_security_header(
2484        headers,
2485        header::CACHE_CONTROL,
2486        cfg.cache_control.as_deref(),
2487        "no-store, max-age=0",
2488    );
2489    apply_security_header(
2490        headers,
2491        header::REFERRER_POLICY,
2492        cfg.referrer_policy.as_deref(),
2493        "no-referrer",
2494    );
2495    apply_security_header(
2496        headers,
2497        HeaderName::from_static("cross-origin-opener-policy"),
2498        cfg.cross_origin_opener_policy.as_deref(),
2499        "same-origin",
2500    );
2501    apply_security_header(
2502        headers,
2503        HeaderName::from_static("cross-origin-resource-policy"),
2504        cfg.cross_origin_resource_policy.as_deref(),
2505        "same-origin",
2506    );
2507    apply_security_header(
2508        headers,
2509        HeaderName::from_static("cross-origin-embedder-policy"),
2510        cfg.cross_origin_embedder_policy.as_deref(),
2511        "require-corp",
2512    );
2513    apply_security_header(
2514        headers,
2515        HeaderName::from_static("permissions-policy"),
2516        cfg.permissions_policy.as_deref(),
2517        "accelerometer=(), camera=(), geolocation=(), microphone=()",
2518    );
2519    apply_security_header(
2520        headers,
2521        HeaderName::from_static("x-permitted-cross-domain-policies"),
2522        cfg.x_permitted_cross_domain_policies.as_deref(),
2523        "none",
2524    );
2525    apply_security_header(
2526        headers,
2527        HeaderName::from_static("content-security-policy"),
2528        cfg.content_security_policy.as_deref(),
2529        "default-src 'none'; frame-ancestors 'none'",
2530    );
2531    apply_security_header(
2532        headers,
2533        HeaderName::from_static("x-dns-prefetch-control"),
2534        cfg.x_dns_prefetch_control.as_deref(),
2535        "off",
2536    );
2537
2538    if is_tls {
2539        apply_security_header(
2540            headers,
2541            header::STRICT_TRANSPORT_SECURITY,
2542            cfg.strict_transport_security.as_deref(),
2543            "max-age=63072000; includeSubDomains",
2544        );
2545    }
2546
2547    resp
2548}
2549
2550/// Set a single security header on the response, honouring the
2551/// three-state override semantic (None = default, Some("") = omit,
2552/// Some(value) = override).
2553///
2554/// Defence-in-depth: if an override value somehow reaches this point
2555/// despite [`validate_security_headers`] having approved it (e.g. a
2556/// runtime mutation on a non-`Validated` field), we log at error level
2557/// and fall back to the static default rather than panicking. The
2558/// `Validated<McpServerConfig>` type makes that path unreachable in
2559/// well-typed code paths.
2560fn apply_security_header(
2561    headers: &mut axum::http::HeaderMap,
2562    name: axum::http::HeaderName,
2563    override_value: Option<&str>,
2564    default: &'static str,
2565) {
2566    use axum::http::HeaderValue;
2567
2568    match override_value {
2569        None => {
2570            headers.insert(name, HeaderValue::from_static(default));
2571        }
2572        Some("") => {
2573            // Operator explicitly opted out of this header.
2574        }
2575        Some(v) => match HeaderValue::from_str(v) {
2576            Ok(hv) => {
2577                headers.insert(name, hv);
2578            }
2579            Err(err) => {
2580                tracing::error!(
2581                    header = %name,
2582                    error = %err,
2583                    "invalid security header override reached middleware; using default"
2584                );
2585                headers.insert(name, HeaderValue::from_static(default));
2586            }
2587        },
2588    }
2589}
2590
2591/// Validate every non-empty entry in a [`SecurityHeadersConfig`].
2592///
2593/// - `None` and `Some("")` are accepted unconditionally (use-default and
2594///   omit, respectively).
2595/// - `Some(v)` is rejected if `axum::http::HeaderValue::from_str(v)` fails.
2596/// - `strict_transport_security` additionally rejects any value
2597///   containing `preload` (case-insensitive). Operators who genuinely
2598///   want to commit to the HSTS preload list must do so via a future
2599///   explicit `with_hsts_preload(true)` builder, not by smuggling
2600///   `preload` through this knob.
2601fn validate_security_headers(cfg: &SecurityHeadersConfig) -> Result<(), McpxError> {
2602    use axum::http::HeaderValue;
2603
2604    let fields: &[(&str, Option<&str>)] = &[
2605        (
2606            "x_content_type_options",
2607            cfg.x_content_type_options.as_deref(),
2608        ),
2609        ("x_frame_options", cfg.x_frame_options.as_deref()),
2610        ("cache_control", cfg.cache_control.as_deref()),
2611        ("referrer_policy", cfg.referrer_policy.as_deref()),
2612        (
2613            "cross_origin_opener_policy",
2614            cfg.cross_origin_opener_policy.as_deref(),
2615        ),
2616        (
2617            "cross_origin_resource_policy",
2618            cfg.cross_origin_resource_policy.as_deref(),
2619        ),
2620        (
2621            "cross_origin_embedder_policy",
2622            cfg.cross_origin_embedder_policy.as_deref(),
2623        ),
2624        ("permissions_policy", cfg.permissions_policy.as_deref()),
2625        (
2626            "x_permitted_cross_domain_policies",
2627            cfg.x_permitted_cross_domain_policies.as_deref(),
2628        ),
2629        (
2630            "content_security_policy",
2631            cfg.content_security_policy.as_deref(),
2632        ),
2633        (
2634            "x_dns_prefetch_control",
2635            cfg.x_dns_prefetch_control.as_deref(),
2636        ),
2637        (
2638            "strict_transport_security",
2639            cfg.strict_transport_security.as_deref(),
2640        ),
2641    ];
2642
2643    for (field, value) in fields {
2644        let Some(v) = value else { continue };
2645        if v.is_empty() {
2646            continue;
2647        }
2648        if let Err(err) = HeaderValue::from_str(v) {
2649            return Err(McpxError::Config(format!(
2650                "invalid security_headers.{field}: {err}"
2651            )));
2652        }
2653    }
2654
2655    if let Some(v) = cfg.strict_transport_security.as_deref()
2656        && !v.is_empty()
2657        && v.to_ascii_lowercase().contains("preload")
2658    {
2659        return Err(McpxError::Config(format!(
2660            "invalid security_headers.strict_transport_security: {v:?} contains the `preload` directive; \
2661             HSTS preload must be opted into explicitly via a dedicated builder, not via this knob"
2662        )));
2663    }
2664
2665    Ok(())
2666}
2667
/// Add the cache-related headers that RFC 6749 §5.1 / RFC 6750 §5.4
/// require on OAuth token-issuing responses.
///
/// `Cache-Control: no-store, max-age=0` is already applied globally by
/// [`security_headers_middleware`]; this middleware contributes:
///
/// - `Pragma: no-cache` -- mandated by RFC 6749 §5.1 for HTTP/1.0 caches.
/// - `Vary: Authorization` -- mandated by RFC 6750 §5.4 for endpoints
///   whose response depends on the `Authorization` header.
///
/// Applied only to the OAuth proxy token-class endpoints (`/token`,
/// `/register`, `/introspect`, `/revoke`). `Pragma` is inserted (replacing
/// any existing value) while `Vary` is appended, so a `Vary` value already
/// present (e.g. `Accept-Encoding` from a compression layer, or `Origin`
/// from a CORS layer) is preserved.
#[cfg(feature = "oauth")]
async fn oauth_token_cache_headers_middleware(
    req: Request<Body>,
    next: Next,
) -> axum::response::Response {
    use axum::http::{HeaderValue, header};

    let mut response = next.run(req).await;

    response
        .headers_mut()
        .insert(header::PRAGMA, HeaderValue::from_static("no-cache"));
    response
        .headers_mut()
        .append(header::VARY, HeaderValue::from_static("Authorization"));

    response
}
2695
2696/// Per the MCP spec: if the Origin header is present and its value is not in
2697/// the allowed list, respond with 403 Forbidden. Requests without an Origin
2698/// header are allowed through (e.g. non-browser clients like curl, SDKs).
2699async fn origin_check_middleware(
2700    allowed: Arc<[String]>,
2701    log_request_headers: bool,
2702    req: Request<Body>,
2703    next: Next,
2704) -> axum::response::Response {
2705    let method = req.method().clone();
2706    let path = req.uri().path().to_owned();
2707
2708    log_incoming_request(&method, &path, req.headers(), log_request_headers);
2709
2710    if let Some(origin) = req.headers().get(axum::http::header::ORIGIN) {
2711        let origin_str = origin.to_str().unwrap_or("");
2712        if !allowed.iter().any(|a| a == origin_str) {
2713            tracing::warn!(
2714                origin = origin_str,
2715                %method,
2716                %path,
2717                allowed = ?&*allowed,
2718                "rejected request: Origin not allowed"
2719            );
2720            return (
2721                axum::http::StatusCode::FORBIDDEN,
2722                "Forbidden: Origin not allowed",
2723            )
2724                .into_response();
2725        }
2726    }
2727    next.run(req).await
2728}
2729
2730/// Emit a DEBUG log for an incoming request, optionally including the full
2731/// (redacted) header set.
2732fn log_incoming_request(
2733    method: &axum::http::Method,
2734    path: &str,
2735    headers: &axum::http::HeaderMap,
2736    log_request_headers: bool,
2737) {
2738    if log_request_headers {
2739        tracing::debug!(
2740            %method,
2741            %path,
2742            headers = %format_request_headers_for_log(headers),
2743            "incoming request"
2744        );
2745    } else {
2746        tracing::debug!(%method, %path, "incoming request");
2747    }
2748}
2749
2750fn format_request_headers_for_log(headers: &axum::http::HeaderMap) -> String {
2751    headers
2752        .iter()
2753        .map(|(k, v)| {
2754            let name = k.as_str();
2755            if name == "authorization" || name == "cookie" || name == "proxy-authorization" {
2756                format!("{name}: [REDACTED]")
2757            } else {
2758                format!("{name}: {}", v.to_str().unwrap_or("<non-utf8>"))
2759            }
2760        })
2761        .collect::<Vec<_>>()
2762        .join(", ")
2763}
2764
2765// -- stdio transport --
2766
2767/// Serve an MCP server over stdin/stdout (stdio transport).
2768///
2769/// # Security warnings
2770///
2771/// - **No authentication**: the parent process has full, unrestricted access.
2772/// - **No RBAC**: all tools are available regardless of policy.
2773/// - **No TLS**: messages travel over OS pipes in plaintext.
2774/// - **Single client**: only the parent process can connect.
2775/// - **No Origin validation**: not applicable to stdio.
2776///
2777/// Use this only when the MCP client spawns the server as a trusted subprocess
2778/// (e.g. Claude Desktop, VS Code Copilot). For network-accessible deployments,
2779/// use `serve()` (Streamable HTTP) instead.
2780///
2781/// # Errors
2782///
2783/// Returns [`McpxError::Startup`] if the handler fails to initialize or the
2784/// transport disconnects unexpectedly.
2785// NOTE: reported complexity 32/25 is driven entirely by `tracing::*!`
2786// macro expansion in this 18-line function (info/warn/info + two matches).
2787// There is nothing meaningful to extract; the allow stays.
2788#[allow(
2789    clippy::cognitive_complexity,
2790    reason = "complexity is purely tracing macro expansion (info/warn + match arms); 18 lines of straight-line code, nothing meaningful to extract"
2791)]
2792pub async fn serve_stdio<H>(handler: H) -> Result<(), McpxError>
2793where
2794    H: ServerHandler + 'static,
2795{
2796    use rmcp::ServiceExt as _;
2797
2798    tracing::info!("stdio transport: serving on stdin/stdout");
2799    tracing::warn!("stdio mode: auth, RBAC, TLS, and Origin checks are DISABLED");
2800
2801    let transport = rmcp::transport::io::stdio();
2802
2803    let service = handler
2804        .serve(transport)
2805        .await
2806        .map_err(|e| McpxError::Startup(format!("stdio initialize failed: {e}")))?;
2807
2808    if let Err(e) = service.waiting().await {
2809        tracing::warn!(error = %e, "stdio session ended with error");
2810    }
2811    tracing::info!("stdio session ended");
2812    Ok(())
2813}
2814
2815#[cfg(test)]
2816mod tests {
2817    #![allow(
2818        clippy::unwrap_used,
2819        clippy::expect_used,
2820        clippy::panic,
2821        clippy::indexing_slicing,
2822        clippy::unwrap_in_result,
2823        clippy::print_stdout,
2824        clippy::print_stderr,
2825        deprecated,
2826        reason = "internal unit tests legitimately read/write the deprecated `pub` fields they were designed to verify"
2827    )]
2828    use std::{sync::Arc, time::Duration};
2829
2830    use axum::{
2831        body::Body,
2832        http::{Request, StatusCode, header},
2833        response::IntoResponse,
2834    };
2835    use http_body_util::BodyExt;
2836    use tower::ServiceExt as _;
2837
2838    use super::*;
2839
2840    // -- McpServerConfig --
2841
2842    #[test]
2843    fn server_config_new_defaults() {
2844        let cfg = McpServerConfig::new("0.0.0.0:8443", "test-server", "1.0.0");
2845        assert_eq!(cfg.bind_addr, "0.0.0.0:8443");
2846        assert_eq!(cfg.name, "test-server");
2847        assert_eq!(cfg.version, "1.0.0");
2848        assert!(cfg.tls_cert_path.is_none());
2849        assert!(cfg.tls_key_path.is_none());
2850        assert!(cfg.auth.is_none());
2851        assert!(cfg.rbac.is_none());
2852        assert!(cfg.allowed_origins.is_empty());
2853        assert!(cfg.tool_rate_limit.is_none());
2854        assert!(cfg.readiness_check.is_none());
2855        assert_eq!(cfg.max_request_body, 1024 * 1024);
2856        assert_eq!(cfg.request_timeout, Duration::from_mins(2));
2857        assert_eq!(cfg.shutdown_timeout, Duration::from_secs(30));
2858        assert!(!cfg.log_request_headers);
2859        assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(10));
2860        assert_eq!(cfg.max_concurrent_tls_handshakes, 256);
2861    }
2862
2863    #[test]
2864    fn tls_handshake_builders_set_fields() {
2865        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
2866            .with_tls_handshake_timeout(Duration::from_secs(3))
2867            .with_max_concurrent_tls_handshakes(64);
2868        assert_eq!(cfg.tls_handshake_timeout, Duration::from_secs(3));
2869        assert_eq!(cfg.max_concurrent_tls_handshakes, 64);
2870    }
2871
2872    #[test]
2873    fn validate_rejects_zero_tls_handshake_timeout() {
2874        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
2875            .with_tls_handshake_timeout(Duration::ZERO);
2876        let err = cfg.validate().expect_err("zero handshake timeout");
2877        assert!(err.to_string().contains("tls_handshake_timeout"));
2878    }
2879
2880    #[test]
2881    fn validate_rejects_zero_max_concurrent_tls_handshakes() {
2882        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0")
2883            .with_max_concurrent_tls_handshakes(0);
2884        let err = cfg.validate().expect_err("zero handshake concurrency");
2885        assert!(err.to_string().contains("max_concurrent_tls_handshakes"));
2886    }
2887
2888    #[test]
2889    fn validate_consumes_and_proves() {
2890        // Valid config -> Validated wrapper, original is consumed.
2891        let cfg = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
2892        let validated = cfg.validate().expect("valid config");
2893        // as_inner() gives read-only access to inner fields.
2894        assert_eq!(validated.as_inner().name, "test-server");
2895        // into_inner recovers the raw value.
2896        let raw = validated.into_inner();
2897        assert_eq!(raw.name, "test-server");
2898
2899        // Invalid config (zero max_request_body) -> Err.
2900        let mut bad = McpServerConfig::new("127.0.0.1:8080", "test-server", "1.0.0");
2901        bad.max_request_body = 0;
2902        assert!(bad.validate().is_err(), "zero body cap must fail validate");
2903    }
2904
2905    #[test]
2906    fn validate_rejects_zero_max_concurrent_requests() {
2907        let cfg =
2908            McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_max_concurrent_requests(0);
2909        let err = cfg.validate().expect_err("zero concurrency cap must fail");
2910        assert!(
2911            format!("{err}").contains("max_concurrent_requests"),
2912            "error should mention max_concurrent_requests, got: {err}"
2913        );
2914    }
2915
2916    #[test]
2917    fn validate_rejects_zero_max_tracked_keys() {
2918        // Defaults mirror auth::default_max_attempts / default_idle_eviction
2919        // (module-private in auth.rs); spelled out here for review clarity.
2920        let rl = crate::auth::RateLimitConfig {
2921            max_attempts_per_minute: 30,
2922            pre_auth_max_per_minute: None,
2923            max_tracked_keys: 0,
2924            idle_eviction: Duration::from_secs(15 * 60),
2925        };
2926        let auth_cfg = AuthConfig {
2927            enabled: true,
2928            api_keys: Vec::new(),
2929            mtls: None,
2930            rate_limit: Some(rl),
2931            #[cfg(feature = "oauth")]
2932            oauth: None,
2933        };
2934        let cfg = McpServerConfig::new("127.0.0.1:8080", "test", "1.0.0").with_auth(auth_cfg);
2935        let err = cfg.validate().expect_err("zero max_tracked_keys must fail");
2936        assert!(
2937            format!("{err}").contains("max_tracked_keys"),
2938            "error should mention max_tracked_keys, got: {err}"
2939        );
2940    }
2941
2942    #[test]
2943    fn derive_allowed_hosts_includes_public_host() {
2944        let hosts = derive_allowed_hosts("0.0.0.0:8080", Some("https://mcp.example.com/mcp"));
2945        assert!(
2946            hosts.iter().any(|h| h == "mcp.example.com"),
2947            "public_url host must be allowed"
2948        );
2949    }
2950
2951    #[test]
2952    fn derive_allowed_hosts_includes_bind_authority() {
2953        let hosts = derive_allowed_hosts("127.0.0.1:8080", None);
2954        assert!(
2955            hosts.iter().any(|h| h == "127.0.0.1"),
2956            "bind host must be allowed"
2957        );
2958        assert!(
2959            hosts.iter().any(|h| h == "127.0.0.1:8080"),
2960            "bind authority must be allowed"
2961        );
2962    }
2963
2964    // -- healthz --
2965
2966    #[tokio::test]
2967    async fn healthz_returns_ok_json() {
2968        let resp = healthz().await.into_response();
2969        assert_eq!(resp.status(), StatusCode::OK);
2970        let body = resp.into_body().collect().await.unwrap().to_bytes();
2971        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
2972        assert_eq!(json["status"], "ok");
2973        assert!(
2974            json.get("name").is_none(),
2975            "healthz must not expose server name"
2976        );
2977        assert!(
2978            json.get("version").is_none(),
2979            "healthz must not expose version"
2980        );
2981    }
2982
2983    // -- readyz --
2984
2985    #[tokio::test]
2986    async fn readyz_returns_ok_when_ready() {
2987        let check: ReadinessCheck =
2988            Arc::new(|| Box::pin(async { serde_json::json!({"ready": true, "db": "connected"}) }));
2989        let resp = readyz(check).await.into_response();
2990        assert_eq!(resp.status(), StatusCode::OK);
2991        let body = resp.into_body().collect().await.unwrap().to_bytes();
2992        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
2993        assert_eq!(json["ready"], true);
2994        assert!(
2995            json.get("name").is_none(),
2996            "readyz must not expose server name"
2997        );
2998        assert!(
2999            json.get("version").is_none(),
3000            "readyz must not expose version"
3001        );
3002        assert_eq!(json["db"], "connected");
3003    }
3004
3005    #[tokio::test]
3006    async fn readyz_returns_503_when_not_ready() {
3007        let check: ReadinessCheck =
3008            Arc::new(|| Box::pin(async { serde_json::json!({"ready": false}) }));
3009        let resp = readyz(check).await.into_response();
3010        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
3011    }
3012
3013    #[tokio::test]
3014    async fn readyz_returns_503_when_ready_missing() {
3015        let check: ReadinessCheck =
3016            Arc::new(|| Box::pin(async { serde_json::json!({"status": "starting"}) }));
3017        let resp = readyz(check).await.into_response();
3018        // Missing "ready" field defaults to false -> 503
3019        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
3020    }
3021
3022    // -- origin_check_middleware --
3023
3024    /// Build a test router with origin check middleware and a simple handler.
3025    fn origin_router(origins: Vec<String>, log_request_headers: bool) -> axum::Router {
3026        let allowed: Arc<[String]> = Arc::from(origins);
3027        axum::Router::new()
3028            .route("/test", axum::routing::get(|| async { "ok" }))
3029            .layer(axum::middleware::from_fn(move |req, next| {
3030                let a = Arc::clone(&allowed);
3031                origin_check_middleware(a, log_request_headers, req, next)
3032            }))
3033    }
3034
3035    #[tokio::test]
3036    async fn origin_allowed_passes() {
3037        let app = origin_router(vec!["http://localhost:3000".into()], false);
3038        let req = Request::builder()
3039            .uri("/test")
3040            .header(header::ORIGIN, "http://localhost:3000")
3041            .body(Body::empty())
3042            .unwrap();
3043        let resp = app.oneshot(req).await.unwrap();
3044        assert_eq!(resp.status(), StatusCode::OK);
3045    }
3046
3047    #[tokio::test]
3048    async fn origin_rejected_returns_403() {
3049        let app = origin_router(vec!["http://localhost:3000".into()], false);
3050        let req = Request::builder()
3051            .uri("/test")
3052            .header(header::ORIGIN, "http://evil.com")
3053            .body(Body::empty())
3054            .unwrap();
3055        let resp = app.oneshot(req).await.unwrap();
3056        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
3057    }
3058
3059    #[tokio::test]
3060    async fn no_origin_header_passes() {
3061        let app = origin_router(vec!["http://localhost:3000".into()], false);
3062        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3063        let resp = app.oneshot(req).await.unwrap();
3064        assert_eq!(resp.status(), StatusCode::OK);
3065    }
3066
3067    #[tokio::test]
3068    async fn empty_allowlist_rejects_any_origin() {
3069        let app = origin_router(vec![], false);
3070        let req = Request::builder()
3071            .uri("/test")
3072            .header(header::ORIGIN, "http://anything.com")
3073            .body(Body::empty())
3074            .unwrap();
3075        let resp = app.oneshot(req).await.unwrap();
3076        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
3077    }
3078
3079    #[tokio::test]
3080    async fn empty_allowlist_passes_without_origin() {
3081        let app = origin_router(vec![], false);
3082        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3083        let resp = app.oneshot(req).await.unwrap();
3084        assert_eq!(resp.status(), StatusCode::OK);
3085    }
3086
#[test]
fn format_request_headers_redacts_sensitive_values() {
    // Two credential-bearing headers plus one harmless header.
    let mut headers = axum::http::HeaderMap::new();
    for (name, value) in [
        ("authorization", "Bearer secret-token"),
        ("cookie", "sid=abc"),
        ("x-request-id", "req-123"),
    ] {
        headers.insert(name, value.parse().unwrap());
    }

    let rendered = format_request_headers_for_log(&headers);

    // Credentials are masked, the harmless header is logged verbatim, and
    // the raw token never appears anywhere in the output.
    assert!(rendered.contains("authorization: [REDACTED]"));
    assert!(rendered.contains("cookie: [REDACTED]"));
    assert!(rendered.contains("x-request-id: req-123"));
    assert!(!rendered.contains("secret-token"));
}
3100
3101    // -- security_headers_middleware --
3102
3103    fn security_router(is_tls: bool) -> axum::Router {
3104        security_router_with(is_tls, SecurityHeadersConfig::default())
3105    }
3106
3107    fn security_router_with(is_tls: bool, cfg: SecurityHeadersConfig) -> axum::Router {
3108        let cfg = Arc::new(cfg);
3109        axum::Router::new()
3110            .route("/test", axum::routing::get(|| async { "ok" }))
3111            .layer(axum::middleware::from_fn(move |req, next| {
3112                let c = Arc::clone(&cfg);
3113                security_headers_middleware(is_tls, c, req, next)
3114            }))
3115    }
3116
3117    #[tokio::test]
3118    async fn security_headers_set_on_response() {
3119        let app = security_router(false);
3120        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3121        let resp = app.oneshot(req).await.unwrap();
3122        assert_eq!(resp.status(), StatusCode::OK);
3123
3124        let h = resp.headers();
3125        assert_eq!(h.get("x-content-type-options").unwrap(), "nosniff");
3126        assert_eq!(h.get("x-frame-options").unwrap(), "deny");
3127        assert_eq!(h.get("cache-control").unwrap(), "no-store, max-age=0");
3128        assert_eq!(h.get("referrer-policy").unwrap(), "no-referrer");
3129        assert_eq!(h.get("cross-origin-opener-policy").unwrap(), "same-origin");
3130        assert_eq!(
3131            h.get("cross-origin-resource-policy").unwrap(),
3132            "same-origin"
3133        );
3134        assert_eq!(
3135            h.get("cross-origin-embedder-policy").unwrap(),
3136            "require-corp"
3137        );
3138        assert_eq!(h.get("x-permitted-cross-domain-policies").unwrap(), "none");
3139        assert!(
3140            h.get("permissions-policy")
3141                .unwrap()
3142                .to_str()
3143                .unwrap()
3144                .contains("camera=()"),
3145            "permissions-policy must restrict browser features"
3146        );
3147        assert_eq!(
3148            h.get("content-security-policy").unwrap(),
3149            "default-src 'none'; frame-ancestors 'none'"
3150        );
3151        assert_eq!(h.get("x-dns-prefetch-control").unwrap(), "off");
3152        // No HSTS when TLS is off.
3153        assert!(h.get("strict-transport-security").is_none());
3154    }
3155
3156    #[tokio::test]
3157    async fn hsts_set_when_tls_enabled() {
3158        let app = security_router(true);
3159        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3160        let resp = app.oneshot(req).await.unwrap();
3161
3162        let hsts = resp.headers().get("strict-transport-security").unwrap();
3163        assert!(
3164            hsts.to_str().unwrap().contains("max-age=63072000"),
3165            "HSTS must set 2-year max-age"
3166        );
3167    }
3168
3169    // -- SecurityHeadersConfig validation + override semantics --
3170
3171    /// Build a minimal config with a custom SecurityHeadersConfig and
3172    /// drive it through `check()`. Returns the result so individual
3173    /// tests can assert on success or specific error messages.
3174    fn check_with_security_headers(headers: SecurityHeadersConfig) -> Result<(), McpxError> {
3175        let cfg =
3176            McpServerConfig::new("127.0.0.1:8080", "test", "0.0.0").with_security_headers(headers);
3177        cfg.check()
3178    }
3179
#[test]
fn security_headers_config_default_validates() {
    // The out-of-the-box header configuration must pass validation.
    let outcome = check_with_security_headers(SecurityHeadersConfig::default());
    outcome.expect("default SecurityHeadersConfig must validate");
}
3185
#[test]
fn security_headers_config_validate_accepts_empty_string() {
    // `Some("")` means "omit this header". Setting it on all twelve fields
    // (omit-everything mode) must still validate.
    let omit = || Some(String::new());
    let omit_all = SecurityHeadersConfig {
        x_content_type_options: omit(),
        x_frame_options: omit(),
        cache_control: omit(),
        referrer_policy: omit(),
        cross_origin_opener_policy: omit(),
        cross_origin_resource_policy: omit(),
        cross_origin_embedder_policy: omit(),
        permissions_policy: omit(),
        x_permitted_cross_domain_policies: omit(),
        content_security_policy: omit(),
        x_dns_prefetch_control: omit(),
        strict_transport_security: omit(),
    };
    let outcome = check_with_security_headers(omit_all);
    outcome.expect("Some(\"\") on every field must validate (omit-all)");
}
3205
#[test]
fn security_headers_config_validate_rejects_bad_value() {
    // BEL (0x07) is a control character and not legal in an HTTP header
    // value, so validation has to refuse it.
    let invalid = SecurityHeadersConfig {
        referrer_policy: Some("\u{0007}".to_owned()),
        ..SecurityHeadersConfig::default()
    };
    let msg = check_with_security_headers(invalid)
        .expect_err("control char in referrer_policy must reject")
        .to_string();
    assert!(
        msg.contains("referrer_policy"),
        "error must name the offending field, got: {msg}"
    );
}
3221
#[test]
fn security_headers_config_validate_rejects_hsts_preload() {
    // An HSTS override containing the `preload` directive must be refused,
    // and the error has to identify both the field and the directive.
    let with_preload = SecurityHeadersConfig {
        strict_transport_security: Some(
            "max-age=63072000; includeSubDomains; preload".to_owned(),
        ),
        ..SecurityHeadersConfig::default()
    };
    let msg = check_with_security_headers(with_preload)
        .expect_err("HSTS with preload must reject")
        .to_string();
    assert!(
        msg.contains("strict_transport_security"),
        "error must name the field, got: {msg}"
    );
    let lowered = msg.to_lowercase();
    assert!(
        lowered.contains("preload"),
        "error must mention `preload`, got: {msg}"
    );
}
3239
#[test]
fn security_headers_config_validate_rejects_hsts_preload_uppercase() {
    // The `preload` detection must not depend on letter case.
    let shouting = SecurityHeadersConfig {
        strict_transport_security: Some("max-age=600; PRELOAD".to_owned()),
        ..SecurityHeadersConfig::default()
    };
    let outcome = check_with_security_headers(shouting);
    outcome.expect_err("HSTS preload check must be case-insensitive");
}
3249
3250    #[tokio::test]
3251    async fn security_headers_override_honored() {
3252        // Override X-Frame-Options to SAMEORIGIN.
3253        let h = SecurityHeadersConfig {
3254            x_frame_options: Some("SAMEORIGIN".into()),
3255            ..SecurityHeadersConfig::default()
3256        };
3257        let app = security_router_with(false, h);
3258        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3259        let resp = app.oneshot(req).await.unwrap();
3260        assert_eq!(resp.status(), StatusCode::OK);
3261
3262        let xfo = resp.headers().get("x-frame-options").unwrap();
3263        assert_eq!(xfo, "SAMEORIGIN");
3264    }
3265
3266    #[tokio::test]
3267    async fn security_headers_empty_string_omits() {
3268        // Empty string on referrer-policy -> header absent.
3269        let h = SecurityHeadersConfig {
3270            referrer_policy: Some(String::new()),
3271            ..SecurityHeadersConfig::default()
3272        };
3273        let app = security_router_with(false, h);
3274        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3275        let resp = app.oneshot(req).await.unwrap();
3276        assert_eq!(resp.status(), StatusCode::OK);
3277
3278        assert!(
3279            resp.headers().get("referrer-policy").is_none(),
3280            "Some(\"\") must omit the header"
3281        );
3282        // Other defaults should still be present.
3283        assert_eq!(
3284            resp.headers().get("x-content-type-options").unwrap(),
3285            "nosniff"
3286        );
3287    }
3288
3289    #[tokio::test]
3290    async fn security_headers_hsts_only_when_tls() {
3291        // HSTS override is irrelevant when TLS is off.
3292        let h = SecurityHeadersConfig {
3293            strict_transport_security: Some("max-age=600".into()),
3294            ..SecurityHeadersConfig::default()
3295        };
3296        let app = security_router_with(false, h);
3297        let req = Request::builder().uri("/test").body(Body::empty()).unwrap();
3298        let resp = app.oneshot(req).await.unwrap();
3299        assert!(
3300            resp.headers().get("strict-transport-security").is_none(),
3301            "HSTS must remain absent on plaintext deployments even with override"
3302        );
3303    }
3304
3305    // -- oauth_token_cache_headers_middleware --
3306
#[cfg(feature = "oauth")]
#[tokio::test]
async fn oauth_token_cache_headers_set_pragma_and_vary() {
    // Minimal token endpoint wrapped in the cache-header middleware.
    let cache_layer = axum::middleware::from_fn(oauth_token_cache_headers_middleware);
    let router = axum::Router::new()
        .route("/token", axum::routing::post(|| async { "{}" }))
        .layer(cache_layer);

    let request = Request::builder()
        .method("POST")
        .uri("/token")
        .body(Body::from("{}"))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let h = response.headers();
    assert_eq!(
        h.get("pragma").unwrap(),
        "no-cache",
        "RFC 6749 §5.1: token responses must set Pragma: no-cache"
    );

    // Gather every Vary value that is valid text; values that cannot be
    // rendered as a string are skipped.
    let mut vary_values: Vec<String> = Vec::new();
    for value in h.get_all("vary").iter() {
        if let Ok(text) = value.to_str() {
            vary_values.push(text.to_owned());
        }
    }
    let names_authorization = vary_values
        .iter()
        .any(|v| v.eq_ignore_ascii_case("Authorization"));
    assert!(
        names_authorization,
        "RFC 6750 §5.4: Vary must include Authorization, got {vary_values:?}"
    );
}
3341
#[cfg(feature = "oauth")]
#[tokio::test]
async fn oauth_token_cache_headers_preserve_existing_vary() {
    // The handler stands in for an inner handler/layer (e.g. compression)
    // that has already set `Vary: Accept-Encoding`. The middleware under
    // test is expected to ADD to that header rather than overwrite it.
    let handler = axum::routing::post(|| async {
        axum::response::Response::builder()
            .header("vary", "Accept-Encoding")
            .body(axum::body::Body::from("{}"))
            .unwrap()
    });
    let cache_layer = axum::middleware::from_fn(oauth_token_cache_headers_middleware);
    let router = axum::Router::new()
        .route("/token", handler)
        .layer(cache_layer);

    let request = Request::builder()
        .method("POST")
        .uri("/token")
        .body(Body::empty())
        .unwrap();
    let response = router.oneshot(request).await.unwrap();

    // Gather every Vary value that is valid text.
    let mut vary: Vec<String> = Vec::new();
    for value in response.headers().get_all("vary").iter() {
        if let Ok(text) = value.to_str() {
            vary.push(text.to_owned());
        }
    }

    let kept_original = vary.iter().any(|v| v.contains("Accept-Encoding"));
    assert!(
        kept_original,
        "must preserve pre-existing Vary value, got {vary:?}"
    );
    let added_authorization = vary.iter().any(|v| v.contains("Authorization"));
    assert!(
        added_authorization,
        "must append Authorization to Vary, got {vary:?}"
    );
}
3382
3383    // -- version endpoint --
3384
#[test]
fn version_payload_contains_expected_fields() {
    let v = version_payload("my-server", "1.2.3");

    // The caller-supplied identity is echoed back verbatim.
    assert_eq!(v["name"], "my-server");
    assert_eq!(v["version"], "1.2.3");

    // Build metadata only has to be present as strings; the concrete
    // values are not pinned by this test.
    for key in [
        "build_git_sha",
        "build_timestamp",
        "rust_version",
        "mcpx_version",
    ] {
        assert!(v[key].is_string());
    }
}
3395
3396    // -- concurrency limit layer --
3397
3398    #[tokio::test]
3399    async fn concurrency_limit_layer_composes_and_serves() {
3400        // We only assert the layer stack compiles and a single request
3401        // below the cap still succeeds. True back-pressure behaviour
3402        // requires a live HTTP server and is covered by integration tests.
3403        let app = axum::Router::new()
3404            .route("/ok", axum::routing::get(|| async { "ok" }))
3405            .layer(
3406                tower::ServiceBuilder::new()
3407                    .layer(axum::error_handling::HandleErrorLayer::new(
3408                        |_err: tower::BoxError| async { StatusCode::SERVICE_UNAVAILABLE },
3409                    ))
3410                    .layer(tower::load_shed::LoadShedLayer::new())
3411                    .layer(tower::limit::ConcurrencyLimitLayer::new(4)),
3412            );
3413        let resp = app
3414            .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
3415            .await
3416            .unwrap();
3417        assert_eq!(resp.status(), StatusCode::OK);
3418    }
3419
3420    // -- compression layer --
3421
3422    #[tokio::test]
3423    async fn compression_layer_gzip_encodes_response() {
3424        use tower_http::compression::Predicate as _;
3425
3426        let big_body = "a".repeat(4096);
3427        let app = axum::Router::new()
3428            .route(
3429                "/big",
3430                axum::routing::get(move || {
3431                    let body = big_body.clone();
3432                    async move { body }
3433                }),
3434            )
3435            .layer(
3436                tower_http::compression::CompressionLayer::new()
3437                    .gzip(true)
3438                    .br(true)
3439                    .compress_when(
3440                        tower_http::compression::DefaultPredicate::new()
3441                            .and(tower_http::compression::predicate::SizeAbove::new(1024)),
3442                    ),
3443            );
3444
3445        let req = Request::builder()
3446            .uri("/big")
3447            .header(header::ACCEPT_ENCODING, "gzip")
3448            .body(Body::empty())
3449            .unwrap();
3450        let resp = app.oneshot(req).await.unwrap();
3451        assert_eq!(resp.status(), StatusCode::OK);
3452        assert_eq!(
3453            resp.headers().get(header::CONTENT_ENCODING).unwrap(),
3454            "gzip"
3455        );
3456    }
3457
3458    // -- TlsListener handshake timeout --
3459
3460    #[tokio::test]
3461    async fn tls_handshake_timeout_reaps_idle_connections() {
3462        use tokio::io::AsyncReadExt as _;
3463
3464        let _ = rustls::crypto::ring::default_provider().install_default();
3465
3466        // Self-signed cert material on disk (TlsListener::new takes paths).
3467        let key = rcgen::KeyPair::generate().expect("generate key");
3468        let cert = rcgen::CertificateParams::new(vec!["localhost".to_owned()])
3469            .expect("cert params")
3470            .self_signed(&key)
3471            .expect("self-signed cert");
3472        let dir = std::env::temp_dir().join(format!(
3473            "rmcp-server-kit-hs-timeout-{}",
3474            std::time::SystemTime::now()
3475                .duration_since(std::time::UNIX_EPOCH)
3476                .expect("clock after epoch")
3477                .as_nanos()
3478        ));
3479        tokio::fs::create_dir_all(&dir).await.expect("temp dir");
3480        let cert_path = dir.join("server.crt");
3481        let key_path = dir.join("server.key");
3482        tokio::fs::write(&cert_path, cert.pem())
3483            .await
3484            .expect("write cert");
3485        tokio::fs::write(&key_path, key.serialize_pem())
3486            .await
3487            .expect("write key");
3488
3489        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
3490        let tls = TlsListener::new(
3491            listener,
3492            &cert_path,
3493            &key_path,
3494            None,
3495            None,
3496            Duration::from_millis(200),
3497            8, // custom concurrency cap: proves the plumbing end-to-end
3498        )
3499        .expect("tls listener");
3500        let addr = axum::serve::Listener::local_addr(&tls).expect("local addr");
3501
3502        // Connect and send NOTHING: the handshake worker must time out
3503        // after 200ms and drop the stream, which the client observes as
3504        // EOF or a reset well within the 2s deadline.
3505        let mut idle = tokio::net::TcpStream::connect(addr).await.expect("connect");
3506        let mut buf = [0_u8; 16];
3507        let read = tokio::time::timeout(Duration::from_secs(2), idle.read(&mut buf))
3508            .await
3509            .expect("server must reap the idle handshake within its timeout");
3510        match read {
3511            Ok(0) | Err(_) => {} // EOF or reset: connection was dropped.
3512            Ok(n) => panic!("unexpected {n} bytes from server during reaped handshake"),
3513        }
3514
3515        drop(tls);
3516    }
3517}