Skip to main content

rmcp_server_kit/
auth.rs

1//! Authentication middleware for MCP servers.
2//!
3//! Supports multiple authentication methods tried in priority order:
4//! 1. mTLS client certificate (if configured and peer cert present)
5//! 2. Bearer token (API key) with Argon2id hash verification
6//!
7//! Includes per-source-IP rate limiting on authentication attempts.
8
9use std::{
10    collections::HashSet,
11    net::{IpAddr, SocketAddr},
12    num::NonZeroU32,
13    path::PathBuf,
14    sync::{
15        Arc, LazyLock, Mutex,
16        atomic::{AtomicU64, Ordering},
17    },
18    time::Duration,
19};
20
21use arc_swap::ArcSwap;
22use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier, password_hash::SaltString};
23use axum::{
24    body::Body,
25    extract::ConnectInfo,
26    http::{Request, header},
27    middleware::Next,
28    response::{IntoResponse, Response},
29};
30use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
31use secrecy::SecretString;
32use serde::Deserialize;
33use x509_parser::prelude::*;
34
35use crate::{bounded_limiter::BoundedKeyedLimiter, error::McpxError};
36
37/// Identity of an authenticated caller.
38///
39/// The [`Debug`] impl is **manually written** to redact the raw bearer token
40/// and the JWT `sub` claim. This prevents accidental disclosure if an
41/// `AuthIdentity` is ever logged via `tracing::debug!(?identity, …)` or
42/// `format!("{identity:?}")`. Only `name`, `role`, and `method` are printed
43/// in the clear; `raw_token` and `sub` are rendered as `<redacted>` /
44/// `<present>` / `<none>` markers.
45#[derive(Clone)]
46#[non_exhaustive]
47pub struct AuthIdentity {
48    /// Human-readable identity name (e.g. API key label or cert CN).
49    pub name: String,
50    /// RBAC role associated with this identity.
51    pub role: String,
52    /// Which authentication mechanism produced this identity.
53    pub method: AuthMethod,
54    /// Raw bearer token from the `Authorization` header, wrapped in
55    /// [`SecretString`] so it is never accidentally logged or serialized.
56    /// Present for OAuth JWT; `None` for mTLS and API-key auth.
57    /// Tool handlers use this for downstream token passthrough via
58    /// [`crate::rbac::current_token`].
59    pub raw_token: Option<SecretString>,
60    /// JWT `sub` claim (stable user identifier, e.g. Keycloak UUID).
61    /// Used for token store keying. `None` for non-JWT auth.
62    pub sub: Option<String>,
63}
64
65impl std::fmt::Debug for AuthIdentity {
66    /// Redacts `raw_token` and `sub` to prevent secret leakage via
67    /// `format!("{:?}")` or `tracing::debug!(?identity)`.
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        f.debug_struct("AuthIdentity")
70            .field("name", &self.name)
71            .field("role", &self.role)
72            .field("method", &self.method)
73            .field(
74                "raw_token",
75                &if self.raw_token.is_some() {
76                    "<redacted>"
77                } else {
78                    "<none>"
79                },
80            )
81            .field(
82                "sub",
83                &if self.sub.is_some() {
84                    "<redacted>"
85                } else {
86                    "<none>"
87                },
88            )
89            .finish()
90    }
91}
92
/// Authentication mechanism that produced an [`AuthIdentity`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum AuthMethod {
    /// Static API key presented as a bearer token and checked against an
    /// Argon2id hash from configuration.
    BearerToken,
    /// Client certificate presented during a mutual-TLS handshake.
    MtlsCertificate,
    /// OAuth 2.1 JWT bearer token, validated against a JWKS.
    OAuthJwt,
}
104
/// Classification of a rejected authentication attempt.
///
/// Each class maps to a failure counter, a bearer `(error, description)`
/// pair and a plain-text response body.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum AuthFailureClass {
    /// Neither a bearer token nor a client certificate accompanied the request.
    MissingCredential,
    /// A credential was supplied but was malformed or did not match.
    InvalidCredential,
    /// The credential was recognised but has expired.
    #[cfg_attr(not(feature = "oauth"), allow(dead_code))]
    ExpiredCredential,
    /// Source IP exceeded the post-failure backoff limit.
    RateLimited,
    /// Source IP exceeded the pre-auth abuse gate. The request was rejected
    /// before any password-hash work (see [`AuthState::pre_auth_limiter`]).
    PreAuthGate,
}
117
118impl AuthFailureClass {
119    fn as_str(self) -> &'static str {
120        match self {
121            Self::MissingCredential => "missing_credential",
122            Self::InvalidCredential => "invalid_credential",
123            Self::ExpiredCredential => "expired_credential",
124            Self::RateLimited => "rate_limited",
125            Self::PreAuthGate => "pre_auth_gate",
126        }
127    }
128
129    fn bearer_error(self) -> (&'static str, &'static str) {
130        match self {
131            Self::MissingCredential => (
132                "invalid_request",
133                "missing bearer token or mTLS client certificate",
134            ),
135            Self::InvalidCredential => ("invalid_token", "token is invalid"),
136            Self::ExpiredCredential => ("invalid_token", "token is expired"),
137            Self::RateLimited => ("invalid_request", "too many failed authentication attempts"),
138            Self::PreAuthGate => (
139                "invalid_request",
140                "too many unauthenticated requests from this source",
141            ),
142        }
143    }
144
145    fn response_body(self) -> &'static str {
146        match self {
147            Self::MissingCredential => "unauthorized: missing credential",
148            Self::InvalidCredential => "unauthorized: invalid credential",
149            Self::ExpiredCredential => "unauthorized: expired credential",
150            Self::RateLimited => "rate limited",
151            Self::PreAuthGate => "rate limited (pre-auth)",
152        }
153    }
154}
155
156/// Snapshot of authentication success/failure counters.
157#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
158#[non_exhaustive]
159pub struct AuthCountersSnapshot {
160    /// Successful mTLS authentications.
161    pub success_mtls: u64,
162    /// Successful bearer-token authentications.
163    pub success_bearer: u64,
164    /// Successful OAuth JWT authentications.
165    pub success_oauth_jwt: u64,
166    /// Failures because no credential was presented.
167    pub failure_missing_credential: u64,
168    /// Failures because the credential was malformed or wrong.
169    pub failure_invalid_credential: u64,
170    /// Failures because the credential had expired.
171    pub failure_expired_credential: u64,
172    /// Failures because the source IP was rate-limited (post-failure backoff).
173    pub failure_rate_limited: u64,
174    /// Failures because the source IP exceeded the pre-auth abuse gate.
175    /// These never reach the password-hash verification path.
176    pub failure_pre_auth_gate: u64,
177}
178
/// Atomic counters that back [`AuthCountersSnapshot`], one per reported
/// field. All counters start at zero via `Default`.
#[derive(Default, Debug)]
pub(crate) struct AuthCounters {
    success_mtls: AtomicU64,
    success_bearer: AtomicU64,
    success_oauth_jwt: AtomicU64,
    failure_missing_credential: AtomicU64,
    failure_invalid_credential: AtomicU64,
    failure_expired_credential: AtomicU64,
    failure_rate_limited: AtomicU64,
    failure_pre_auth_gate: AtomicU64,
}
191
192impl AuthCounters {
193    fn record_success(&self, method: AuthMethod) {
194        match method {
195            AuthMethod::MtlsCertificate => {
196                self.success_mtls.fetch_add(1, Ordering::Relaxed);
197            }
198            AuthMethod::BearerToken => {
199                self.success_bearer.fetch_add(1, Ordering::Relaxed);
200            }
201            AuthMethod::OAuthJwt => {
202                self.success_oauth_jwt.fetch_add(1, Ordering::Relaxed);
203            }
204        }
205    }
206
207    fn record_failure(&self, class: AuthFailureClass) {
208        match class {
209            AuthFailureClass::MissingCredential => {
210                self.failure_missing_credential
211                    .fetch_add(1, Ordering::Relaxed);
212            }
213            AuthFailureClass::InvalidCredential => {
214                self.failure_invalid_credential
215                    .fetch_add(1, Ordering::Relaxed);
216            }
217            AuthFailureClass::ExpiredCredential => {
218                self.failure_expired_credential
219                    .fetch_add(1, Ordering::Relaxed);
220            }
221            AuthFailureClass::RateLimited => {
222                self.failure_rate_limited.fetch_add(1, Ordering::Relaxed);
223            }
224            AuthFailureClass::PreAuthGate => {
225                self.failure_pre_auth_gate.fetch_add(1, Ordering::Relaxed);
226            }
227        }
228    }
229
230    fn snapshot(&self) -> AuthCountersSnapshot {
231        AuthCountersSnapshot {
232            success_mtls: self.success_mtls.load(Ordering::Relaxed),
233            success_bearer: self.success_bearer.load(Ordering::Relaxed),
234            success_oauth_jwt: self.success_oauth_jwt.load(Ordering::Relaxed),
235            failure_missing_credential: self.failure_missing_credential.load(Ordering::Relaxed),
236            failure_invalid_credential: self.failure_invalid_credential.load(Ordering::Relaxed),
237            failure_expired_credential: self.failure_expired_credential.load(Ordering::Relaxed),
238            failure_rate_limited: self.failure_rate_limited.load(Ordering::Relaxed),
239            failure_pre_auth_gate: self.failure_pre_auth_gate.load(Ordering::Relaxed),
240        }
241    }
242}
243
244/// RFC 3339 timestamp, parsed at deserialization time.
245///
246/// Use this for any public field that needs to carry an RFC 3339 timestamp from
247/// TOML/JSON config or builder APIs. Construction is fallible (`parse`); once
248/// constructed the value is guaranteed to be a real RFC 3339 timestamp with a
249/// known offset, so downstream code does not need to handle parse errors.
250///
251/// Wraps [`chrono::DateTime<chrono::FixedOffset>`]; the underlying value is
252/// available via [`Self::as_datetime`] or [`Self::into_inner`]. `Serialize`
253/// emits the canonical RFC 3339 form via [`chrono::DateTime::to_rfc3339`], so
254/// the on-the-wire format for `ApiKeySummary` (admin endpoints) is unchanged.
255#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
256#[non_exhaustive]
257pub struct RfcTimestamp(chrono::DateTime<chrono::FixedOffset>);
258
259impl RfcTimestamp {
260    /// Parse an RFC 3339 timestamp.
261    ///
262    /// # Errors
263    ///
264    /// Returns the underlying [`chrono::ParseError`] when `s` is not a valid
265    /// RFC 3339 timestamp (e.g. missing the `T` separator, missing the offset
266    /// suffix, or out-of-range fields).
267    pub fn parse(s: &str) -> Result<Self, chrono::ParseError> {
268        chrono::DateTime::parse_from_rfc3339(s).map(Self)
269    }
270
271    /// Borrow the underlying [`chrono::DateTime`].
272    #[must_use]
273    pub fn as_datetime(&self) -> &chrono::DateTime<chrono::FixedOffset> {
274        &self.0
275    }
276
277    /// Consume the wrapper and return the underlying [`chrono::DateTime`].
278    #[must_use]
279    pub fn into_inner(self) -> chrono::DateTime<chrono::FixedOffset> {
280        self.0
281    }
282}
283
284impl std::fmt::Display for RfcTimestamp {
285    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286        // Canonical RFC 3339 form; matches the deserialization input contract.
287        write!(f, "{}", self.0.to_rfc3339())
288    }
289}
290
291impl std::fmt::Debug for RfcTimestamp {
292    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
293        // Render as the canonical RFC 3339 string (not chrono's internal
294        // debug form) so existing `ApiKeyEntry` Debug-redaction tests --
295        // which look for the literal `"2030-01-01T00:00:00Z"` form in the
296        // formatted output -- continue to hold without bespoke handling.
297        write!(f, "{}", self.0.to_rfc3339())
298    }
299}
300
301impl<'de> Deserialize<'de> for RfcTimestamp {
302    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
303    where
304        D: serde::Deserializer<'de>,
305    {
306        // Validate at deserialization time: a malformed `expires_at` in
307        // TOML or JSON aborts config load with a clear serde error rather
308        // than silently producing a key that fails open at runtime.
309        let s = String::deserialize(deserializer)?;
310        Self::parse(&s).map_err(serde::de::Error::custom)
311    }
312}
313
314impl serde::Serialize for RfcTimestamp {
315    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
316    where
317        S: serde::Serializer,
318    {
319        serializer.serialize_str(&self.0.to_rfc3339())
320    }
321}
322
323impl From<chrono::DateTime<chrono::FixedOffset>> for RfcTimestamp {
324    fn from(value: chrono::DateTime<chrono::FixedOffset>) -> Self {
325        Self(value)
326    }
327}
328
329/// A single API key entry (stored as Argon2id hash in config).
330///
331/// The [`Debug`] impl is **manually written** to redact the Argon2id hash.
332/// Although the hash is not directly reversible, treating it as a secret
333/// prevents offline brute-force attempts from leaked logs and matches the
334/// defense-in-depth posture used for [`AuthIdentity`].
335#[derive(Clone, Deserialize)]
336#[non_exhaustive]
337pub struct ApiKeyEntry {
338    /// Human-readable key label (used in logs and audit records).
339    pub name: String,
340    /// Argon2id hash of the token (PHC string format).
341    pub hash: String,
342    /// RBAC role granted when this key authenticates successfully.
343    pub role: String,
344    /// Optional expiry, parsed from an RFC 3339 string at deserialization
345    /// time. Construction from a raw string is fallible (see
346    /// [`RfcTimestamp::parse`] and [`ApiKeyEntry::try_with_expiry`]),
347    /// which guarantees `verify_bearer_token` never sees a malformed value.
348    pub expires_at: Option<RfcTimestamp>,
349}
350
351impl std::fmt::Debug for ApiKeyEntry {
352    /// Redacts the Argon2id `hash` to keep it out of logs, panic backtraces,
353    /// and admin-endpoint responses that might `format!("{:?}", …)` an entry.
354    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355        f.debug_struct("ApiKeyEntry")
356            .field("name", &self.name)
357            .field("hash", &"<redacted>")
358            .field("role", &self.role)
359            .field("expires_at", &self.expires_at)
360            .finish()
361    }
362}
363
364impl ApiKeyEntry {
365    /// Create a new API key entry (no expiry).
366    #[must_use]
367    pub fn new(name: impl Into<String>, hash: impl Into<String>, role: impl Into<String>) -> Self {
368        Self {
369            name: name.into(),
370            hash: hash.into(),
371            role: role.into(),
372            expires_at: None,
373        }
374    }
375
376    /// Set an RFC 3339 expiry on this key.
377    ///
378    /// Takes an already-parsed [`RfcTimestamp`]; for ergonomic construction
379    /// from a raw string see [`Self::try_with_expiry`].
380    #[must_use]
381    pub fn with_expiry(mut self, expires_at: RfcTimestamp) -> Self {
382        self.expires_at = Some(expires_at);
383        self
384    }
385
386    /// Set an RFC 3339 expiry on this key from a raw string.
387    ///
388    /// # Errors
389    ///
390    /// Returns the underlying [`chrono::ParseError`] when `expires_at` is
391    /// not a valid RFC 3339 timestamp. This is the fallible counterpart to
392    /// [`Self::with_expiry`].
393    pub fn try_with_expiry(
394        mut self,
395        expires_at: impl AsRef<str>,
396    ) -> Result<Self, chrono::ParseError> {
397        self.expires_at = Some(RfcTimestamp::parse(expires_at.as_ref())?);
398        Ok(self)
399    }
400}
401
402/// mTLS client certificate authentication configuration.
403#[derive(Debug, Clone, Deserialize)]
404#[allow(
405    clippy::struct_excessive_bools,
406    reason = "mTLS CRL behavior is intentionally configured as independent booleans"
407)]
408#[non_exhaustive]
409pub struct MtlsConfig {
410    /// Path to CA certificate(s) for verifying client certs (PEM format).
411    pub ca_cert_path: PathBuf,
412    /// If true, clients MUST present a valid certificate.
413    /// If false, client certs are optional (verified if presented).
414    #[serde(default)]
415    pub required: bool,
416    /// Default RBAC role for mTLS-authenticated clients.
417    /// The client cert CN becomes the identity name.
418    #[serde(default = "default_mtls_role")]
419    pub default_role: String,
420    /// Enable CRL-based certificate revocation checks using CDP URLs from the
421    /// configured CA chain and connecting client certificates.
422    #[serde(default = "default_true")]
423    pub crl_enabled: bool,
424    /// Optional fixed refresh interval for known CRLs. When omitted, refresh
425    /// cadence is derived from `nextUpdate` and clamped internally.
426    #[serde(default, with = "humantime_serde::option")]
427    pub crl_refresh_interval: Option<Duration>,
428    /// Timeout for individual CRL fetches.
429    #[serde(default = "default_crl_fetch_timeout", with = "humantime_serde")]
430    pub crl_fetch_timeout: Duration,
431    /// Grace window during which stale CRLs may still be used when refresh
432    /// attempts fail.
433    #[serde(default = "default_crl_stale_grace", with = "humantime_serde")]
434    pub crl_stale_grace: Duration,
435    /// When true, missing or unavailable CRLs cause revocation checks to fail
436    /// closed.
437    #[serde(default)]
438    pub crl_deny_on_unavailable: bool,
439    /// When true, apply revocation checks only to the end-entity certificate.
440    #[serde(default)]
441    pub crl_end_entity_only: bool,
442    /// Allow HTTP CRL distribution-point URLs in addition to HTTPS.
443    ///
444    /// Defaults to `true` because RFC 5280 §4.2.1.13 designates HTTP (and
445    /// LDAP) as the canonical transport for CRL distribution points.
446    /// SSRF defense for HTTP CDPs is provided by the IP-allowlist guard
447    /// (private/loopback/link-local/multicast/cloud-metadata addresses are
448    /// always rejected), redirect=none, body-size cap, and per-host
449    /// concurrency limit -- not by forcing HTTPS.
450    #[serde(default = "default_true")]
451    pub crl_allow_http: bool,
452    /// Enforce CRL expiration during certificate validation.
453    #[serde(default = "default_true")]
454    pub crl_enforce_expiration: bool,
455    /// Maximum concurrent CRL fetches across all hosts. Defense in depth
456    /// against SSRF amplification: even if many CDPs are discovered, no
457    /// more than this many fetches run in parallel. Per-host concurrency
458    /// is independently capped at 1 regardless of this value.
459    /// Default: `4`.
460    #[serde(default = "default_crl_max_concurrent_fetches")]
461    pub crl_max_concurrent_fetches: usize,
462    /// Hard cap on each CRL response body in bytes. Fetches exceeding this
463    /// are aborted mid-stream to bound memory and prevent gzip-bomb-style
464    /// amplification. Default: 5 MiB (`5 * 1024 * 1024`).
465    #[serde(default = "default_crl_max_response_bytes")]
466    pub crl_max_response_bytes: u64,
467    /// Global CDP discovery rate limit, in URLs per minute. Throttles
468    /// how many *new* CDP URLs the verifier may admit into the fetch
469    /// pipeline across the whole process, bounding asymmetric `DoS`
470    /// amplification when attacker-controlled certificates carry large
471    /// CDP lists. The limit is global (not per-source-IP) in this
472    /// release; per-IP scoping is deferred to a future version because
473    /// it requires plumbing the peer `SocketAddr` through the verifier
474    /// hook. URLs that lose the rate-limiter race are *not* marked as
475    /// seen, so subsequent handshakes observing the same URL can
476    /// retry admission.
477    /// Default: `60`.
478    #[serde(default = "default_crl_discovery_rate_per_min")]
479    pub crl_discovery_rate_per_min: u32,
480    /// Maximum number of distinct hosts that may hold a CRL fetch
481    /// semaphore at any time. Requests that would grow the map beyond
482    /// this cap return [`McpxError::Config`] containing the literal
483    /// substring `"crl_host_semaphore_cap_exceeded"`. Bounds memory
484    /// growth from attacker-controlled CDP URLs pointing at unique
485    /// hostnames. Default: 1024.
486    #[serde(default = "default_crl_max_host_semaphores")]
487    pub crl_max_host_semaphores: usize,
488    /// Maximum number of distinct URLs tracked in the "seen" set.
489    /// Beyond this, additional discovered URLs are silently dropped
490    /// with a rate-limited warn! log; no error surfaces. Default: 4096.
491    #[serde(default = "default_crl_max_seen_urls")]
492    pub crl_max_seen_urls: usize,
493    /// Maximum number of cached CRL entries. Beyond this, new
494    /// successful fetches are silently dropped with a rate-limited
495    /// warn! log (newest-rejected, not LRU-evicted). Default: 1024.
496    #[serde(default = "default_crl_max_cache_entries")]
497    pub crl_max_cache_entries: usize,
498}
499
/// Serde default for `MtlsConfig::default_role`: the `"viewer"` role.
fn default_mtls_role() -> String {
    String::from("viewer")
}

/// Serde default for boolean fields that are enabled unless configured off.
const fn default_true() -> bool {
    true
}

/// Serde default for `crl_fetch_timeout`: 30 seconds.
const fn default_crl_fetch_timeout() -> Duration {
    Duration::from_millis(30_000)
}

/// Serde default for `crl_stale_grace`: 24 hours.
const fn default_crl_stale_grace() -> Duration {
    Duration::from_secs(24 * 60 * 60)
}

/// Serde default for `crl_max_concurrent_fetches`.
const fn default_crl_max_concurrent_fetches() -> usize {
    4
}

/// Serde default for `crl_max_response_bytes`: 5 MiB.
const fn default_crl_max_response_bytes() -> u64 {
    5 << 20
}

/// Serde default for `crl_discovery_rate_per_min`.
const fn default_crl_discovery_rate_per_min() -> u32 {
    60
}

/// Serde default for `crl_max_host_semaphores`.
const fn default_crl_max_host_semaphores() -> usize {
    1024
}

/// Serde default for `crl_max_seen_urls`.
const fn default_crl_max_seen_urls() -> usize {
    4096
}

/// Serde default for `crl_max_cache_entries`.
const fn default_crl_max_cache_entries() -> usize {
    1024
}
539
540/// Rate limiting configuration for authentication attempts.
541///
542/// rmcp-server-kit uses two independent per-IP token-bucket limiters for auth:
543///
544/// 1. **Pre-auth abuse gate** ([`Self::pre_auth_max_per_minute`]): consulted
545///    *before* any password-hash work. Throttles unauthenticated traffic from
546///    a single source IP so an attacker cannot pin the CPU on Argon2id by
547///    spraying invalid bearer tokens. Sized generously (default = 10× the
548///    post-failure quota) so legitimate clients are unaffected. mTLS-
549///    authenticated connections bypass this gate entirely (the TLS handshake
550///    already performed expensive crypto with a verified peer).
551/// 2. **Post-failure backoff** ([`Self::max_attempts_per_minute`]): consulted
552///    *after* an authentication attempt fails. Provides explicit backpressure
553///    on bad credentials.
554#[derive(Debug, Clone, Deserialize)]
555#[non_exhaustive]
556pub struct RateLimitConfig {
557    /// Maximum failed authentication attempts per source IP per minute.
558    /// Successful authentications do not consume this budget.
559    #[serde(default = "default_max_attempts")]
560    pub max_attempts_per_minute: u32,
561    /// Maximum *unauthenticated* requests per source IP per minute admitted
562    /// to the password-hash verification path. When `None`, defaults to
563    /// `max_attempts_per_minute * 10` at limiter-construction time.
564    ///
565    /// Set higher than [`Self::max_attempts_per_minute`] so honest clients
566    /// retrying with the wrong key never trip this gate; its purpose is only
567    /// to bound CPU usage under spray attacks.
568    #[serde(default)]
569    pub pre_auth_max_per_minute: Option<u32>,
570    /// Hard cap on the number of distinct source IPs tracked per limiter.
571    /// When reached, idle entries are pruned first; if still full, the
572    /// oldest (LRU) entry is evicted to make room for the new one. This
573    /// bounds memory under IP-spray attacks. Default: `10_000`.
574    #[serde(default = "default_max_tracked_keys")]
575    pub max_tracked_keys: usize,
576    /// Per-IP entries idle for longer than this are eligible for
577    /// opportunistic pruning. Default: 15 minutes.
578    #[serde(default = "default_idle_eviction", with = "humantime_serde")]
579    pub idle_eviction: Duration,
580}
581
582impl Default for RateLimitConfig {
583    fn default() -> Self {
584        Self {
585            max_attempts_per_minute: default_max_attempts(),
586            pre_auth_max_per_minute: None,
587            max_tracked_keys: default_max_tracked_keys(),
588            idle_eviction: default_idle_eviction(),
589        }
590    }
591}
592
593impl RateLimitConfig {
594    /// Create a rate limit config with the given max failed attempts per minute.
595    /// Pre-auth gate defaults to `10x` this value at limiter-construction time.
596    /// Memory-bound defaults are `10_000` tracked keys with 15-minute idle eviction.
597    #[must_use]
598    pub fn new(max_attempts_per_minute: u32) -> Self {
599        Self {
600            max_attempts_per_minute,
601            ..Self::default()
602        }
603    }
604
605    /// Override the pre-auth abuse-gate quota (per source IP per minute).
606    /// When unset, defaults to `max_attempts_per_minute * 10`.
607    #[must_use]
608    pub fn with_pre_auth_max_per_minute(mut self, quota: u32) -> Self {
609        self.pre_auth_max_per_minute = Some(quota);
610        self
611    }
612
613    /// Override the per-limiter cap on tracked source-IP keys (default `10_000`).
614    #[must_use]
615    pub fn with_max_tracked_keys(mut self, max: usize) -> Self {
616        self.max_tracked_keys = max;
617        self
618    }
619
620    /// Override the idle-eviction window (default 15 minutes).
621    #[must_use]
622    pub fn with_idle_eviction(mut self, idle: Duration) -> Self {
623        self.idle_eviction = idle;
624        self
625    }
626}
627
/// Serde default for `RateLimitConfig::max_attempts_per_minute`: 30 failed
/// attempts per source IP per minute.
///
/// `const fn` for consistency with the other serde default helpers in this
/// module.
const fn default_max_attempts() -> u32 {
    30
}

/// Serde default for `RateLimitConfig::max_tracked_keys`: the cap on
/// distinct source IPs tracked per limiter.
const fn default_max_tracked_keys() -> usize {
    10_000
}

/// Serde default for `RateLimitConfig::idle_eviction`: 15 minutes.
const fn default_idle_eviction() -> Duration {
    Duration::from_secs(15 * 60)
}
639
640/// Authentication configuration.
641#[derive(Debug, Clone, Default, Deserialize)]
642#[non_exhaustive]
643pub struct AuthConfig {
644    /// Master switch - when false, all requests are allowed through.
645    #[serde(default)]
646    pub enabled: bool,
647    /// Bearer token API keys.
648    #[serde(default)]
649    pub api_keys: Vec<ApiKeyEntry>,
650    /// mTLS client certificate authentication.
651    pub mtls: Option<MtlsConfig>,
652    /// Rate limiting for auth attempts.
653    pub rate_limit: Option<RateLimitConfig>,
654    /// OAuth 2.1 JWT bearer token authentication.
655    #[cfg(feature = "oauth")]
656    pub oauth: Option<crate::oauth::OAuthConfig>,
657}
658
659impl AuthConfig {
660    /// Create an enabled auth config with the given API keys.
661    #[must_use]
662    pub fn with_keys(keys: Vec<ApiKeyEntry>) -> Self {
663        Self {
664            enabled: true,
665            api_keys: keys,
666            mtls: None,
667            rate_limit: None,
668            #[cfg(feature = "oauth")]
669            oauth: None,
670        }
671    }
672
673    /// Set rate limiting on this auth config.
674    #[must_use]
675    pub fn with_rate_limit(mut self, rate_limit: RateLimitConfig) -> Self {
676        self.rate_limit = Some(rate_limit);
677        self
678    }
679}
680
681/// Summary of a single API key suitable for admin endpoints.
682///
683/// Intentionally omits the Argon2id hash - only metadata is exposed.
684#[derive(Debug, Clone, serde::Serialize)]
685#[non_exhaustive]
686pub struct ApiKeySummary {
687    /// Human-readable key label.
688    pub name: String,
689    /// RBAC role granted when this key authenticates.
690    pub role: String,
691    /// Optional RFC 3339 expiry timestamp. Serialized as a canonical
692    /// RFC 3339 string so the admin-endpoint wire format is preserved.
693    pub expires_at: Option<RfcTimestamp>,
694}
695
696/// Snapshot of the enabled authentication methods for admin endpoints.
697#[derive(Debug, Clone, serde::Serialize)]
698#[allow(
699    clippy::struct_excessive_bools,
700    reason = "this is a flat summary of independent auth-method booleans"
701)]
702#[non_exhaustive]
703pub struct AuthConfigSummary {
704    /// Master enabled flag from config.
705    pub enabled: bool,
706    /// Whether API-key bearer auth is configured.
707    pub bearer: bool,
708    /// Whether mTLS client auth is configured.
709    pub mtls: bool,
710    /// Whether OAuth JWT validation is configured.
711    pub oauth: bool,
712    /// Current API-key list (no hashes).
713    pub api_keys: Vec<ApiKeySummary>,
714}
715
716impl AuthConfig {
717    /// Produce a hash-free summary of the auth config for admin endpoints.
718    #[must_use]
719    pub fn summary(&self) -> AuthConfigSummary {
720        AuthConfigSummary {
721            enabled: self.enabled,
722            bearer: !self.api_keys.is_empty(),
723            mtls: self.mtls.is_some(),
724            #[cfg(feature = "oauth")]
725            oauth: self.oauth.is_some(),
726            #[cfg(not(feature = "oauth"))]
727            oauth: false,
728            api_keys: self
729                .api_keys
730                .iter()
731                .map(|k| ApiKeySummary {
732                    name: k.name.clone(),
733                    role: k.role.clone(),
734                    expires_at: k.expires_at,
735                })
736                .collect(),
737        }
738    }
739}
740
741/// Keyed rate limiter type (per source IP). Memory-bounded by
742/// [`RateLimitConfig::max_tracked_keys`] to defend against IP-spray `DoS`.
743pub(crate) type KeyedLimiter = BoundedKeyedLimiter<IpAddr>;
744
745/// Connection info for TLS connections, carrying the peer socket address
746/// and (when mTLS is configured) the verified client identity extracted
747/// from the peer certificate during the TLS handshake.
748///
749/// Defined as a local type so we can implement axum's `Connected` trait
750/// for our custom `TlsListener` without orphan rule issues. The `identity`
751/// field travels with the connection itself (via the wrapping IO type),
752/// so there is no shared map to race against, no port-reuse aliasing, and
753/// no eviction policy to maintain.
754#[derive(Clone, Debug)]
755#[non_exhaustive]
756pub(crate) struct TlsConnInfo {
757    /// Remote peer socket address.
758    pub addr: SocketAddr,
759    /// Verified mTLS client identity, if a client certificate was presented
760    /// and successfully extracted during the TLS handshake.
761    pub identity: Option<AuthIdentity>,
762}
763
764impl TlsConnInfo {
765    /// Construct a new [`TlsConnInfo`].
766    #[must_use]
767    pub(crate) const fn new(addr: SocketAddr, identity: Option<AuthIdentity>) -> Self {
768        Self { addr, identity }
769    }
770}
771
772/// Shared state for the auth middleware.
773///
774/// `api_keys` uses [`ArcSwap`] so the SIGHUP handler can atomically
775/// swap in a new key list without blocking in-flight requests.
776#[allow(
777    missing_debug_implementations,
778    reason = "contains governor RateLimiter and JwksCache without Debug impls"
779)]
780#[non_exhaustive]
781pub(crate) struct AuthState {
782    /// Active set of API keys (hot-swappable).
783    pub api_keys: ArcSwap<Vec<ApiKeyEntry>>,
784    /// Optional per-IP post-failure rate limiter (consulted *after* auth fails).
785    pub rate_limiter: Option<Arc<KeyedLimiter>>,
786    /// Optional per-IP pre-auth abuse gate (consulted *before* password-hash work).
787    /// mTLS-authenticated connections bypass this gate.
788    pub pre_auth_limiter: Option<Arc<KeyedLimiter>>,
789    #[cfg(feature = "oauth")]
790    /// Optional JWKS cache for OAuth JWT validation.
791    pub jwks_cache: Option<Arc<crate::oauth::JwksCache>>,
792    /// Tracks identity names that have already been logged at INFO level.
793    /// Subsequent auths for the same identity are logged at DEBUG.
794    pub seen_identities: Mutex<HashSet<String>>,
795    /// Lightweight in-memory auth success/failure counters for diagnostics.
796    pub counters: AuthCounters,
797}
798
799impl AuthState {
800    /// Atomically replace the API key list (lock-free, wait-free).
801    ///
802    /// New requests immediately see the updated keys.
803    /// In-flight requests that already loaded the old list finish
804    /// using it -- no torn reads.
805    pub(crate) fn reload_keys(&self, keys: Vec<ApiKeyEntry>) {
806        let count = keys.len();
807        self.api_keys.store(Arc::new(keys));
808        tracing::info!(keys = count, "API keys reloaded");
809    }
810
811    /// Snapshot auth counters for diagnostics and tests.
812    #[must_use]
813    pub(crate) fn counters_snapshot(&self) -> AuthCountersSnapshot {
814        self.counters.snapshot()
815    }
816
817    /// Produce the admin-endpoint list of API keys (metadata only, no hashes).
818    #[must_use]
819    pub(crate) fn api_key_summaries(&self) -> Vec<ApiKeySummary> {
820        self.api_keys
821            .load()
822            .iter()
823            .map(|k| ApiKeySummary {
824                name: k.name.clone(),
825                role: k.role.clone(),
826                expires_at: k.expires_at,
827            })
828            .collect()
829    }
830
831    /// Log auth success: INFO on first occurrence per identity, DEBUG after.
832    fn log_auth(&self, id: &AuthIdentity, method: &str) {
833        self.counters.record_success(id.method);
834        let first = self
835            .seen_identities
836            .lock()
837            .unwrap_or_else(std::sync::PoisonError::into_inner)
838            .insert(id.name.clone());
839        if first {
840            tracing::info!(name = %id.name, role = %id.role, "{method} authenticated");
841        } else {
842            tracing::debug!(name = %id.name, role = %id.role, "{method} authenticated");
843        }
844    }
845}
846
/// Fallback post-failure auth rate: 30 attempts per minute per source IP.
// The literal is non-zero, so the `None` arm can never be taken; because this
// is evaluated in a `const`, any mistake would surface as a compile error.
const DEFAULT_AUTH_RATE: NonZeroU32 = match NonZeroU32::new(30) {
    Some(rate) => rate,
    None => panic!("30 is non-zero"),
};
850
851/// Create a post-failure rate limiter from config.
852#[must_use]
853pub(crate) fn build_rate_limiter(config: &RateLimitConfig) -> Arc<KeyedLimiter> {
854    let quota = governor::Quota::per_minute(
855        NonZeroU32::new(config.max_attempts_per_minute).unwrap_or(DEFAULT_AUTH_RATE),
856    );
857    Arc::new(BoundedKeyedLimiter::new(
858        quota,
859        config.max_tracked_keys,
860        config.idle_eviction,
861    ))
862}
863
864/// Create a pre-auth abuse-gate rate limiter from config.
865///
866/// Quota: `pre_auth_max_per_minute` if set, otherwise
867/// `max_attempts_per_minute * 10` (capped at `u32::MAX`). The 10× factor
868/// keeps the gate generous enough for honest retries while still bounding
869/// attacker CPU on Argon2 verification.
870#[must_use]
871pub(crate) fn build_pre_auth_limiter(config: &RateLimitConfig) -> Arc<KeyedLimiter> {
872    let resolved = config.pre_auth_max_per_minute.unwrap_or_else(|| {
873        config
874            .max_attempts_per_minute
875            .saturating_mul(PRE_AUTH_DEFAULT_MULTIPLIER)
876    });
877    let quota =
878        governor::Quota::per_minute(NonZeroU32::new(resolved).unwrap_or(DEFAULT_PRE_AUTH_RATE));
879    Arc::new(BoundedKeyedLimiter::new(
880        quota,
881        config.max_tracked_keys,
882        config.idle_eviction,
883    ))
884}
885
/// Factor applied to `max_attempts_per_minute` to derive the pre-auth gate
/// quota whenever `pre_auth_max_per_minute` is left unset.
const PRE_AUTH_DEFAULT_MULTIPLIER: u32 = 10;
889
/// Fallback pre-auth abuse-gate rate: 300 requests per minute per source IP.
///
/// Used whenever the resolved pre-auth quota is zero (which
/// `NonZeroU32::new` rejects). That happens in either of two ways:
/// - `pre_auth_max_per_minute` is explicitly `Some(0)`, in which case the
///   multiplied fallback is never consulted; or
/// - it is unset and `max_attempts_per_minute` is zero, so the multiplied
///   fallback is also zero.
///
/// 300 equals `DEFAULT_AUTH_RATE` (30) times `PRE_AUTH_DEFAULT_MULTIPLIER` (10).
// SAFETY: unwrap() is safe - literal 300 is provably non-zero (const-evaluated).
const DEFAULT_PRE_AUTH_RATE: NonZeroU32 = NonZeroU32::new(300).unwrap();
894
895/// Parse an mTLS client certificate and extract an `AuthIdentity`.
896///
897/// Reads the Subject CN as the identity name. Falls back to the first
898/// DNS SAN if CN is absent. The role is taken from the `MtlsConfig`.
899#[must_use]
900pub fn extract_mtls_identity(cert_der: &[u8], default_role: &str) -> Option<AuthIdentity> {
901    let (_, cert) = X509Certificate::from_der(cert_der).ok()?;
902
903    // Try CN from Subject first.
904    let cn = cert
905        .subject()
906        .iter_common_name()
907        .next()
908        .and_then(|attr| attr.as_str().ok())
909        .map(String::from);
910
911    // Fall back to first DNS SAN.
912    let name = cn.or_else(|| {
913        cert.subject_alternative_name()
914            .ok()
915            .flatten()
916            .and_then(|san| {
917                #[allow(clippy::wildcard_enum_match_arm)]
918                san.value.general_names.iter().find_map(|gn| match gn {
919                    GeneralName::DNSName(dns) => Some((*dns).to_owned()),
920                    _ => None,
921                })
922            })
923    })?;
924
925    // Reject identities with characters unsafe for logging and RBAC matching.
926    if !name
927        .chars()
928        .all(|c| c.is_alphanumeric() || matches!(c, '-' | '.' | '_' | '@'))
929    {
930        tracing::warn!(cn = %name, "mTLS identity rejected: invalid characters in CN/SAN");
931        return None;
932    }
933
934    Some(AuthIdentity {
935        name,
936        role: default_role.to_owned(),
937        method: AuthMethod::MtlsCertificate,
938        raw_token: None,
939        sub: None,
940    })
941}
942
/// Pull the credentials out of a `Bearer` `Authorization` header value.
///
/// The auth-scheme is compared case-insensitively, as RFC 7235 §2.1 requires,
/// so `Bearer`, `bearer`, `BEARER` and `BeArEr` are all accepted. The scheme
/// ends at the first space. Any additional spaces before the token are
/// skipped (RFC 7235 allows one or more SP; the single-space form is the
/// common one).
///
/// Yields `None` when:
/// - the value contains no space at all,
/// - the scheme is not `Bearer`, or
/// - nothing but spaces follows the scheme.
///
/// Only the scheme prefix is handled here. Validating the token itself
/// (length, charset, signature, ...) is the caller's job.
fn extract_bearer(value: &str) -> Option<&str> {
    let (scheme, credentials) = value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let token = credentials.trim_start_matches(' ');
    (!token.is_empty()).then_some(token)
}
966
/// Verify a bearer token against configured API keys.
///
/// Argon2id verification is CPU-intensive, so this should be called via
/// `spawn_blocking`. Returns the identity of the first non-expired key whose
/// stored hash parses and matches `token`, or `None` if no key matches. The
/// returned identity never carries the token itself (`raw_token` and `sub`
/// are `None`).
///
/// # Timing-side-channel resistance
///
/// Always performs **exactly one Argon2id verification per configured key**,
/// regardless of:
///
/// * which slot (if any) matches the presented token,
/// * whether a key has expired, or
/// * whether a key's stored hash fails to parse as a PHC string.
///
/// Expired, malformed, and post-match slots are verified against an internal
/// dummy PHC hash (`DUMMY_PHC_HASH`) instead of being skipped. That hash is
/// generated once per process with a random salt and the `Argon2::default()`
/// cost parameters, the same parameters `generate_api_key` uses. This bounds
/// the observable timing to "one Argon2 per configured key" regardless of
/// which (if any) slot held the matching credential. It closes the
/// first-match latency oracle (CWE-208) and the expired-slot timing leak.
///
/// NOTE(review): real verifications run with the cost parameters embedded in
/// each stored hash. A key hashed with non-default parameters therefore
/// takes a different time than the dummy verification, so the equal-time
/// argument only holds when every configured hash uses the defaults.
///
/// `subtle::ConstantTimeEq` computes the "first real match" bit.
/// NOTE(review): that bit is converted to `bool` and branched on to record
/// `matched_index`. The guarded work is a single store, not an Argon2 call,
/// so it does not reintroduce a per-slot latency difference, but it is still
/// a data-dependent branch.
///
/// # Panics
///
/// Panics if the internal dummy PHC hash cannot be parsed as an Argon2id PHC string.
/// This is impossible by construction: the static is generated by
/// [`argon2::Argon2::hash_password`] which always emits a valid PHC string.
#[must_use]
pub fn verify_bearer_token(token: &str, keys: &[ApiKeyEntry]) -> Option<AuthIdentity> {
    use subtle::ConstantTimeEq as _;

    // Sample the clock once so every slot is judged against the same instant.
    let now = chrono::Utc::now();
    let dummy_hash = PasswordHash::new(&DUMMY_PHC_HASH)
        .expect("DUMMY_PHC_HASH is a valid Argon2id PHC string by construction");

    // Index of the first genuinely matching slot; `usize::MAX` means "none yet"
    // and is only read after `any_match` confirms a match.
    let mut matched_index: usize = usize::MAX;
    // 0/1 flag: has any non-expired slot with a well-formed hash matched so far?
    let mut any_match: u8 = 0;

    for (idx, key) in keys.iter().enumerate() {
        let expired = key.expires_at.is_some_and(|exp| exp.as_datetime() < &now);

        let real_hash = PasswordHash::new(&key.hash);
        // Only a live, well-formed slot before the first match is checked
        // against its real hash; every other slot burns an equivalent Argon2
        // verification against the dummy hash instead of being skipped.
        let verify_against = match (&real_hash, expired, any_match) {
            (Ok(h), false, 0) => h,
            _ => &dummy_hash,
        };

        let slot_ok = u8::from(
            Argon2::default()
                .verify_password(token.as_bytes(), verify_against)
                .is_ok(),
        );

        // Mask out expired and malformed slots so a dummy verification there
        // can never count as a match.
        let real_match = slot_ok & u8::from(!expired) & u8::from(real_hash.is_ok());
        // 1 only for the first real match; `any_match` is 0 or 1, so the
        // subtraction cannot underflow.
        let first_real_match = real_match & (1 - any_match);
        if first_real_match.ct_eq(&1).into() {
            matched_index = idx;
        }
        any_match |= real_match;
    }

    if any_match == 0 {
        return None;
    }
    let key = keys.get(matched_index)?;
    Some(AuthIdentity {
        name: key.name.clone(),
        role: key.role.clone(),
        method: AuthMethod::BearerToken,
        raw_token: None,
        sub: None,
    })
}
1041
1042/// Fixed Argon2id PHC hash used as a constant-time placeholder when an
1043/// API-key slot is expired, malformed, or follows the matching slot.
1044///
1045/// Generated once on first access using the same default Argon2 cost
1046/// parameters as live verifications, so the dummy verify takes
1047/// indistinguishable wall time from a real one. The plaintext
1048/// (`"rmcp-server-kit-dummy"`) is unrelated to any real credential.
1049static DUMMY_PHC_HASH: LazyLock<String> = LazyLock::new(|| {
1050    let salt = SaltString::generate(&mut argon2::password_hash::rand_core::OsRng);
1051    Argon2::default()
1052        .hash_password(b"rmcp-server-kit-dummy", &salt)
1053        .expect("Argon2 default params hash a fixed plaintext")
1054        .to_string()
1055});
1056
1057/// Generate a new API key: 256-bit random token + Argon2id hash.
1058///
1059/// Returns `(plaintext_token, argon2id_hash_phc_string)`.
1060/// The plaintext is shown once to the user and never stored.
1061///
1062/// # Errors
1063///
1064/// Returns an error if salt encoding or Argon2id hashing fails
1065/// (should not happen with valid inputs, but we avoid panicking).
1066pub fn generate_api_key() -> Result<(String, String), McpxError> {
1067    let mut token_bytes = [0u8; 32];
1068    rand::fill(&mut token_bytes);
1069    let token = URL_SAFE_NO_PAD.encode(token_bytes);
1070
1071    // Generate 16 random bytes for salt, encode as base64 for SaltString.
1072    let mut salt_bytes = [0u8; 16];
1073    rand::fill(&mut salt_bytes);
1074    let salt = SaltString::encode_b64(&salt_bytes)
1075        .map_err(|e| McpxError::Auth(format!("salt encoding failed: {e}")))?;
1076    let hash = Argon2::default()
1077        .hash_password(token.as_bytes(), &salt)
1078        .map_err(|e| McpxError::Auth(format!("argon2id hashing failed: {e}")))?
1079        .to_string();
1080
1081    Ok((token, hash))
1082}
1083
1084fn build_www_authenticate_value(
1085    advertise_resource_metadata: bool,
1086    failure: AuthFailureClass,
1087) -> String {
1088    let (error, error_description) = failure.bearer_error();
1089    if advertise_resource_metadata {
1090        return format!(
1091            "Bearer resource_metadata=\"/.well-known/oauth-protected-resource\", error=\"{error}\", error_description=\"{error_description}\""
1092        );
1093    }
1094    format!("Bearer error=\"{error}\", error_description=\"{error_description}\"")
1095}
1096
1097fn auth_method_label(method: AuthMethod) -> &'static str {
1098    match method {
1099        AuthMethod::MtlsCertificate => "mTLS",
1100        AuthMethod::BearerToken => "bearer token",
1101        AuthMethod::OAuthJwt => "OAuth JWT",
1102    }
1103}
1104
1105#[cfg_attr(not(feature = "oauth"), allow(unused_variables))]
1106fn unauthorized_response(state: &AuthState, failure_class: AuthFailureClass) -> Response {
1107    #[cfg(feature = "oauth")]
1108    let advertise_resource_metadata = state.jwks_cache.is_some();
1109    #[cfg(not(feature = "oauth"))]
1110    let advertise_resource_metadata = false;
1111
1112    let challenge = build_www_authenticate_value(advertise_resource_metadata, failure_class);
1113    (
1114        axum::http::StatusCode::UNAUTHORIZED,
1115        [(header::WWW_AUTHENTICATE, challenge)],
1116        failure_class.response_body(),
1117    )
1118        .into_response()
1119}
1120
1121async fn authenticate_bearer_identity(
1122    state: &AuthState,
1123    token: &str,
1124) -> Result<AuthIdentity, AuthFailureClass> {
1125    let mut failure_class = AuthFailureClass::MissingCredential;
1126
1127    #[cfg(feature = "oauth")]
1128    if let Some(ref cache) = state.jwks_cache
1129        && crate::oauth::looks_like_jwt(token)
1130    {
1131        match cache.validate_token_with_reason(token).await {
1132            Ok(mut id) => {
1133                id.raw_token = Some(SecretString::from(token.to_owned()));
1134                return Ok(id);
1135            }
1136            Err(crate::oauth::JwtValidationFailure::Expired) => {
1137                failure_class = AuthFailureClass::ExpiredCredential;
1138            }
1139            Err(crate::oauth::JwtValidationFailure::Invalid) => {
1140                failure_class = AuthFailureClass::InvalidCredential;
1141            }
1142        }
1143    }
1144
1145    let token = token.to_owned();
1146    let keys = state.api_keys.load_full(); // Arc clone, lock-free
1147
1148    // Argon2id is CPU-bound - offload to blocking thread pool.
1149    let identity = tokio::task::spawn_blocking(move || verify_bearer_token(&token, &keys))
1150        .await
1151        .ok()
1152        .flatten();
1153
1154    if let Some(id) = identity {
1155        return Ok(id);
1156    }
1157
1158    if failure_class == AuthFailureClass::MissingCredential {
1159        failure_class = AuthFailureClass::InvalidCredential;
1160    }
1161
1162    Err(failure_class)
1163}
1164
1165/// Consult the pre-auth abuse gate for the given peer.
1166///
1167/// Returns `Some(response)` if the request should be rejected (limiter
1168/// configured AND quota exhausted for this source IP). Returns `None`
1169/// otherwise (limiter absent, peer address unknown, or quota available),
1170/// in which case the caller should proceed with credential verification.
1171///
1172/// Side effects on rejection: increments the `pre_auth_gate` failure
1173/// counter and emits a warn-level log. mTLS-authenticated requests must
1174/// be admitted by the caller *before* invoking this helper.
1175fn pre_auth_gate(state: &AuthState, peer_addr: Option<SocketAddr>) -> Option<Response> {
1176    let limiter = state.pre_auth_limiter.as_ref()?;
1177    let addr = peer_addr?;
1178    if limiter.check_key(&addr.ip()).is_ok() {
1179        return None;
1180    }
1181    state.counters.record_failure(AuthFailureClass::PreAuthGate);
1182    tracing::warn!(
1183        ip = %addr.ip(),
1184        "auth rate limited by pre-auth gate (request rejected before credential verification)"
1185    );
1186    Some(
1187        McpxError::RateLimited("too many unauthenticated requests from this source".into())
1188            .into_response(),
1189    )
1190}
1191
1192/// Axum middleware that enforces authentication.
1193///
1194/// Tries authentication methods in priority order:
1195/// 1. mTLS client certificate identity (populated by TLS acceptor)
1196/// 2. Bearer token from `Authorization` header
1197///
1198/// Failed authentication attempts are rate-limited per source IP.
1199/// Successful authentications do not consume rate limit budget.
1200pub(crate) async fn auth_middleware(
1201    state: Arc<AuthState>,
1202    req: Request<Body>,
1203    next: Next,
1204) -> Response {
1205    // Extract peer address (and any mTLS identity) from ConnectInfo.
1206    // Plain TCP: ConnectInfo<SocketAddr>. TLS / mTLS: ConnectInfo<TlsConnInfo>,
1207    // which carries the verified identity directly on the connection — no
1208    // shared map, no port-reuse aliasing.
1209    let tls_info = req.extensions().get::<ConnectInfo<TlsConnInfo>>().cloned();
1210    let peer_addr = req
1211        .extensions()
1212        .get::<ConnectInfo<SocketAddr>>()
1213        .map(|ci| ci.0)
1214        .or_else(|| tls_info.as_ref().map(|ci| ci.0.addr));
1215
1216    // 1. Try mTLS identity (extracted by the TLS acceptor during handshake
1217    //    and attached to the connection itself).
1218    //
1219    //    mTLS connections bypass the pre-auth abuse gate below: the TLS
1220    //    handshake already performed expensive crypto with a verified peer,
1221    //    so we trust them not to be a CPU-spray attacker.
1222    if let Some(id) = tls_info.and_then(|ci| ci.0.identity) {
1223        state.log_auth(&id, "mTLS");
1224        let mut req = req;
1225        req.extensions_mut().insert(id);
1226        return next.run(req).await;
1227    }
1228
1229    // 2. Pre-auth abuse gate: rejects CPU-spray attacks BEFORE the Argon2id
1230    //    verification path runs. Keyed by source IP. mTLS connections (above)
1231    //    are exempt; this gate only protects the bearer/JWT verification path.
1232    if let Some(blocked) = pre_auth_gate(&state, peer_addr) {
1233        return blocked;
1234    }
1235
1236    let failure_class = if let Some(value) = req.headers().get(header::AUTHORIZATION) {
1237        match value.to_str().ok().and_then(extract_bearer) {
1238            Some(token) => match authenticate_bearer_identity(&state, token).await {
1239                Ok(id) => {
1240                    state.log_auth(&id, auth_method_label(id.method));
1241                    let mut req = req;
1242                    req.extensions_mut().insert(id);
1243                    return next.run(req).await;
1244                }
1245                Err(class) => class,
1246            },
1247            None => AuthFailureClass::InvalidCredential,
1248        }
1249    } else {
1250        AuthFailureClass::MissingCredential
1251    };
1252
1253    tracing::warn!(failure_class = %failure_class.as_str(), "auth failed");
1254
1255    // Rate limit check (applied after auth failure only).
1256    // Successful authentications do not consume rate limit budget.
1257    if let (Some(limiter), Some(addr)) = (&state.rate_limiter, peer_addr)
1258        && limiter.check_key(&addr.ip()).is_err()
1259    {
1260        state.counters.record_failure(AuthFailureClass::RateLimited);
1261        tracing::warn!(ip = %addr.ip(), "auth rate limited after repeated failures");
1262        return McpxError::RateLimited("too many failed authentication attempts".into())
1263            .into_response();
1264    }
1265
1266    state.counters.record_failure(failure_class);
1267    unauthorized_response(&state, failure_class)
1268}
1269
1270#[cfg(test)]
1271mod tests {
1272    use super::*;
1273
1274    #[test]
1275    fn generate_and_verify_api_key() {
1276        let (token, hash) = generate_api_key().unwrap();
1277
1278        // Token is 43 chars (256-bit base64url, no padding)
1279        assert_eq!(token.len(), 43);
1280
1281        // Hash is a valid PHC string
1282        assert!(hash.starts_with("$argon2id$"));
1283
1284        // Verification succeeds with correct token
1285        let keys = vec![ApiKeyEntry {
1286            name: "test".into(),
1287            hash,
1288            role: "viewer".into(),
1289            expires_at: None,
1290        }];
1291        let id = verify_bearer_token(&token, &keys);
1292        assert!(id.is_some());
1293        let id = id.unwrap();
1294        assert_eq!(id.name, "test");
1295        assert_eq!(id.role, "viewer");
1296        assert_eq!(id.method, AuthMethod::BearerToken);
1297    }
1298
1299    #[test]
1300    fn wrong_token_rejected() {
1301        let (_token, hash) = generate_api_key().unwrap();
1302        let keys = vec![ApiKeyEntry {
1303            name: "test".into(),
1304            hash,
1305            role: "viewer".into(),
1306            expires_at: None,
1307        }];
1308        assert!(verify_bearer_token("wrong-token", &keys).is_none());
1309    }
1310
1311    #[test]
1312    fn expired_key_rejected() {
1313        let (token, hash) = generate_api_key().unwrap();
1314        let keys = vec![ApiKeyEntry {
1315            name: "test".into(),
1316            hash,
1317            role: "viewer".into(),
1318            expires_at: Some(RfcTimestamp::parse("2020-01-01T00:00:00Z").unwrap()),
1319        }];
1320        assert!(verify_bearer_token(&token, &keys).is_none());
1321    }
1322
1323    #[test]
1324    fn match_in_last_slot_still_authenticates() {
1325        let (token, hash) = generate_api_key().unwrap();
1326        let (_other_token, other_hash) = generate_api_key().unwrap();
1327        let keys = vec![
1328            ApiKeyEntry {
1329                name: "first".into(),
1330                hash: other_hash.clone(),
1331                role: "viewer".into(),
1332                expires_at: None,
1333            },
1334            ApiKeyEntry {
1335                name: "second".into(),
1336                hash: other_hash,
1337                role: "viewer".into(),
1338                expires_at: None,
1339            },
1340            ApiKeyEntry {
1341                name: "match".into(),
1342                hash,
1343                role: "ops".into(),
1344                expires_at: None,
1345            },
1346        ];
1347        let id = verify_bearer_token(&token, &keys).expect("last-slot match must authenticate");
1348        assert_eq!(id.name, "match");
1349        assert_eq!(id.role, "ops");
1350    }
1351
1352    #[test]
1353    fn expired_slot_before_valid_match_does_not_short_circuit() {
1354        let (token, hash) = generate_api_key().unwrap();
1355        let (_, other_hash) = generate_api_key().unwrap();
1356        let keys = vec![
1357            ApiKeyEntry {
1358                name: "expired".into(),
1359                hash: other_hash,
1360                role: "viewer".into(),
1361                expires_at: Some(RfcTimestamp::parse("2020-01-01T00:00:00Z").unwrap()),
1362            },
1363            ApiKeyEntry {
1364                name: "valid".into(),
1365                hash,
1366                role: "ops".into(),
1367                expires_at: None,
1368            },
1369        ];
1370        let id = verify_bearer_token(&token, &keys)
1371            .expect("valid slot following an expired slot must authenticate");
1372        assert_eq!(id.name, "valid");
1373    }
1374
1375    #[test]
1376    fn malformed_hash_slot_does_not_short_circuit() {
1377        let (token, hash) = generate_api_key().unwrap();
1378        let keys = vec![
1379            ApiKeyEntry {
1380                name: "broken".into(),
1381                hash: "this-is-not-a-phc-string".into(),
1382                role: "viewer".into(),
1383                expires_at: None,
1384            },
1385            ApiKeyEntry {
1386                name: "valid".into(),
1387                hash,
1388                role: "ops".into(),
1389                expires_at: None,
1390            },
1391        ];
1392        let id = verify_bearer_token(&token, &keys)
1393            .expect("valid slot following a malformed-hash slot must authenticate");
1394        assert_eq!(id.name, "valid");
1395    }
1396
1397    // Regression tests for H3 (api_key_expires_at_fail_open).
1398    //
1399    // Prior to 1.6.0 the runtime expiry check used a chained
1400    // `if let Some(_) && let Ok(exp) = parse(_) && exp < now` which
1401    // silently fell through on parse error, letting a key with
1402    // `expires_at = "not-a-date"` authenticate forever. These tests
1403    // pin the type-system fix: malformed RFC 3339 is rejected at
1404    // deserialization time (no `RfcTimestamp` can ever be malformed),
1405    // and the runtime check is a pure comparison with no parse path.
1406
1407    #[test]
1408    fn rfc_timestamp_parse_rejects_malformed() {
1409        for bad in [
1410            "not-a-date",
1411            "",
1412            "2025-13-01T00:00:00Z", // month 13
1413            "2025-01-32T00:00:00Z", // day 32
1414            "2025-01-01T00:00:00",  // missing offset
1415            "01/01/2025",           // wrong format
1416            "2025-01-01T25:00:00Z", // hour 25
1417        ] {
1418            assert!(
1419                RfcTimestamp::parse(bad).is_err(),
1420                "RfcTimestamp::parse must reject {bad:?}"
1421            );
1422        }
1423    }
1424
1425    #[test]
1426    fn rfc_timestamp_parse_accepts_valid() {
1427        for good in [
1428            "2025-01-01T00:00:00Z",
1429            "2025-01-01T00:00:00+00:00",
1430            "2025-12-31T23:59:59-08:00",
1431            "2099-01-01T00:00:00.123456789Z",
1432        ] {
1433            assert!(
1434                RfcTimestamp::parse(good).is_ok(),
1435                "RfcTimestamp::parse must accept {good:?}"
1436            );
1437        }
1438    }
1439
1440    #[test]
1441    fn api_key_entry_deserialize_rejects_malformed_expires_at() {
1442        // TOML with a malformed expires_at must fail to deserialize.
1443        // This is the load-time defense: a typo in auth.toml aborts
1444        // config load with a clear serde error, instead of producing
1445        // a key that authenticates forever (the H3 fail-open).
1446        let toml = r#"
1447            name = "bad-key"
1448            hash = "$argon2id$v=19$m=19456,t=2,p=1$c2FsdA$h4sh"
1449            role = "viewer"
1450            expires_at = "not-a-date"
1451        "#;
1452        let result: Result<ApiKeyEntry, _> = toml::from_str(toml);
1453        assert!(
1454            result.is_err(),
1455            "deserialization must reject malformed expires_at"
1456        );
1457    }
1458
1459    #[test]
1460    fn api_key_entry_deserialize_accepts_valid_expires_at() {
1461        let toml = r#"
1462            name = "good-key"
1463            hash = "$argon2id$v=19$m=19456,t=2,p=1$c2FsdA$h4sh"
1464            role = "viewer"
1465            expires_at = "2099-01-01T00:00:00Z"
1466        "#;
1467        let entry: ApiKeyEntry = toml::from_str(toml).expect("valid RFC 3339 must deserialize");
1468        assert!(entry.expires_at.is_some());
1469    }
1470
1471    #[test]
1472    fn api_key_entry_deserialize_accepts_missing_expires_at() {
1473        // Omitting expires_at must continue to mean "no expiry"; this
1474        // is the documented contract and must survive the H3 fix.
1475        let toml = r#"
1476            name = "eternal-key"
1477            hash = "$argon2id$v=19$m=19456,t=2,p=1$c2FsdA$h4sh"
1478            role = "viewer"
1479        "#;
1480        let entry: ApiKeyEntry = toml::from_str(toml).expect("missing expires_at must deserialize");
1481        assert!(entry.expires_at.is_none());
1482    }
1483
1484    #[test]
1485    fn try_with_expiry_rejects_malformed() {
1486        let entry = ApiKeyEntry::new("k", "hash", "viewer");
1487        assert!(entry.try_with_expiry("not-a-date").is_err());
1488    }
1489
1490    #[test]
1491    fn try_with_expiry_accepts_valid() {
1492        let entry = ApiKeyEntry::new("k", "hash", "viewer")
1493            .try_with_expiry("2099-01-01T00:00:00Z")
1494            .expect("valid RFC 3339 must be accepted");
1495        assert!(entry.expires_at.is_some());
1496    }
1497
1498    #[test]
1499    fn api_key_summary_serializes_expires_at_as_rfc3339() {
1500        // The admin endpoint wire format is `{"expires_at": "RFC 3339 str"}`.
1501        // Pinning this prevents an accidental serialization-format change
1502        // (e.g. chrono's debug form, a Unix timestamp) that would silently
1503        // break operator tooling that parses these payloads.
1504        let summary = ApiKeySummary {
1505            name: "k".into(),
1506            role: "viewer".into(),
1507            expires_at: Some(RfcTimestamp::parse("2030-01-01T00:00:00Z").unwrap()),
1508        };
1509        let json = serde_json::to_string(&summary).unwrap();
1510        assert!(
1511            json.contains(r#""expires_at":"2030-01-01T00:00:00+00:00""#),
1512            "wire format regressed: {json}"
1513        );
1514    }
1515
1516    #[test]
1517    fn future_expiry_accepted() {
1518        let (token, hash) = generate_api_key().unwrap();
1519        let keys = vec![ApiKeyEntry {
1520            name: "test".into(),
1521            hash,
1522            role: "viewer".into(),
1523            expires_at: Some(RfcTimestamp::parse("2099-01-01T00:00:00Z").unwrap()),
1524        }];
1525        assert!(verify_bearer_token(&token, &keys).is_some());
1526    }
1527
1528    #[test]
1529    fn multiple_keys_first_match_wins() {
1530        let (token, hash) = generate_api_key().unwrap();
1531        let keys = vec![
1532            ApiKeyEntry {
1533                name: "wrong".into(),
1534                hash: "$argon2id$v=19$m=19456,t=2,p=1$invalid$invalid".into(),
1535                role: "ops".into(),
1536                expires_at: None,
1537            },
1538            ApiKeyEntry {
1539                name: "correct".into(),
1540                hash,
1541                role: "deploy".into(),
1542                expires_at: None,
1543            },
1544        ];
1545        let id = verify_bearer_token(&token, &keys).unwrap();
1546        assert_eq!(id.name, "correct");
1547        assert_eq!(id.role, "deploy");
1548    }
1549
1550    #[test]
1551    fn rate_limiter_allows_within_quota() {
1552        let config = RateLimitConfig {
1553            max_attempts_per_minute: 5,
1554            pre_auth_max_per_minute: None,
1555            ..Default::default()
1556        };
1557        let limiter = build_rate_limiter(&config);
1558        let ip: IpAddr = "10.0.0.1".parse().unwrap();
1559
1560        // First 5 should succeed.
1561        for _ in 0..5 {
1562            assert!(limiter.check_key(&ip).is_ok());
1563        }
1564        // 6th should fail.
1565        assert!(limiter.check_key(&ip).is_err());
1566    }
1567
1568    #[test]
1569    fn rate_limiter_separate_ips() {
1570        let config = RateLimitConfig {
1571            max_attempts_per_minute: 2,
1572            pre_auth_max_per_minute: None,
1573            ..Default::default()
1574        };
1575        let limiter = build_rate_limiter(&config);
1576        let ip1: IpAddr = "10.0.0.1".parse().unwrap();
1577        let ip2: IpAddr = "10.0.0.2".parse().unwrap();
1578
1579        // Exhaust ip1's quota.
1580        assert!(limiter.check_key(&ip1).is_ok());
1581        assert!(limiter.check_key(&ip1).is_ok());
1582        assert!(limiter.check_key(&ip1).is_err());
1583
1584        // ip2 should still have quota.
1585        assert!(limiter.check_key(&ip2).is_ok());
1586    }
1587
1588    #[test]
1589    fn extract_mtls_identity_from_cn() {
1590        // Generate a cert with explicit CN.
1591        let mut params = rcgen::CertificateParams::new(vec!["test-client.local".into()]).unwrap();
1592        params.distinguished_name = rcgen::DistinguishedName::new();
1593        params
1594            .distinguished_name
1595            .push(rcgen::DnType::CommonName, "test-client");
1596        let cert = params
1597            .self_signed(&rcgen::KeyPair::generate().unwrap())
1598            .unwrap();
1599        let der = cert.der();
1600
1601        let id = extract_mtls_identity(der, "ops").unwrap();
1602        assert_eq!(id.name, "test-client");
1603        assert_eq!(id.role, "ops");
1604        assert_eq!(id.method, AuthMethod::MtlsCertificate);
1605    }
1606
1607    #[test]
1608    fn extract_mtls_identity_falls_back_to_san() {
1609        // Cert with no CN but has a DNS SAN.
1610        let mut params =
1611            rcgen::CertificateParams::new(vec!["san-only.example.com".into()]).unwrap();
1612        params.distinguished_name = rcgen::DistinguishedName::new();
1613        // No CN set - should fall back to DNS SAN.
1614        let cert = params
1615            .self_signed(&rcgen::KeyPair::generate().unwrap())
1616            .unwrap();
1617        let der = cert.der();
1618
1619        let id = extract_mtls_identity(der, "viewer").unwrap();
1620        assert_eq!(id.name, "san-only.example.com");
1621        assert_eq!(id.role, "viewer");
1622    }
1623
1624    #[test]
1625    fn extract_mtls_identity_invalid_der() {
1626        assert!(extract_mtls_identity(b"not-a-cert", "viewer").is_none());
1627    }
1628
1629    // -- auth_middleware integration tests --
1630
1631    use axum::{
1632        body::Body,
1633        http::{Request, StatusCode},
1634    };
1635    use tower::ServiceExt as _;
1636
1637    fn auth_router(state: Arc<AuthState>) -> axum::Router {
1638        axum::Router::new()
1639            .route("/mcp", axum::routing::post(|| async { "ok" }))
1640            .layer(axum::middleware::from_fn(move |req, next| {
1641                let s = Arc::clone(&state);
1642                auth_middleware(s, req, next)
1643            }))
1644    }
1645
1646    fn test_auth_state(keys: Vec<ApiKeyEntry>) -> Arc<AuthState> {
1647        Arc::new(AuthState {
1648            api_keys: ArcSwap::new(Arc::new(keys)),
1649            rate_limiter: None,
1650            pre_auth_limiter: None,
1651            #[cfg(feature = "oauth")]
1652            jwks_cache: None,
1653            seen_identities: Mutex::new(HashSet::new()),
1654            counters: AuthCounters::default(),
1655        })
1656    }
1657
1658    #[tokio::test]
1659    async fn middleware_rejects_no_credentials() {
1660        let state = test_auth_state(vec![]);
1661        let app = auth_router(Arc::clone(&state));
1662        let req = Request::builder()
1663            .method(axum::http::Method::POST)
1664            .uri("/mcp")
1665            .body(Body::empty())
1666            .unwrap();
1667        let resp = app.oneshot(req).await.unwrap();
1668        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
1669        let challenge = resp
1670            .headers()
1671            .get(header::WWW_AUTHENTICATE)
1672            .unwrap()
1673            .to_str()
1674            .unwrap();
1675        assert!(challenge.contains("error=\"invalid_request\""));
1676
1677        let counters = state.counters_snapshot();
1678        assert_eq!(counters.failure_missing_credential, 1);
1679    }
1680
1681    #[tokio::test]
1682    async fn middleware_accepts_valid_bearer() {
1683        let (token, hash) = generate_api_key().unwrap();
1684        let keys = vec![ApiKeyEntry {
1685            name: "test-key".into(),
1686            hash,
1687            role: "ops".into(),
1688            expires_at: None,
1689        }];
1690        let state = test_auth_state(keys);
1691        let app = auth_router(Arc::clone(&state));
1692        let req = Request::builder()
1693            .method(axum::http::Method::POST)
1694            .uri("/mcp")
1695            .header("authorization", format!("Bearer {token}"))
1696            .body(Body::empty())
1697            .unwrap();
1698        let resp = app.oneshot(req).await.unwrap();
1699        assert_eq!(resp.status(), StatusCode::OK);
1700
1701        let counters = state.counters_snapshot();
1702        assert_eq!(counters.success_bearer, 1);
1703    }
1704
1705    #[tokio::test]
1706    async fn middleware_rejects_wrong_bearer() {
1707        let (_token, hash) = generate_api_key().unwrap();
1708        let keys = vec![ApiKeyEntry {
1709            name: "test-key".into(),
1710            hash,
1711            role: "ops".into(),
1712            expires_at: None,
1713        }];
1714        let state = test_auth_state(keys);
1715        let app = auth_router(Arc::clone(&state));
1716        let req = Request::builder()
1717            .method(axum::http::Method::POST)
1718            .uri("/mcp")
1719            .header("authorization", "Bearer wrong-token-here")
1720            .body(Body::empty())
1721            .unwrap();
1722        let resp = app.oneshot(req).await.unwrap();
1723        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
1724        let challenge = resp
1725            .headers()
1726            .get(header::WWW_AUTHENTICATE)
1727            .unwrap()
1728            .to_str()
1729            .unwrap();
1730        assert!(challenge.contains("error=\"invalid_token\""));
1731
1732        let counters = state.counters_snapshot();
1733        assert_eq!(counters.failure_invalid_credential, 1);
1734    }
1735
1736    #[tokio::test]
1737    async fn middleware_rate_limits() {
1738        let state = Arc::new(AuthState {
1739            api_keys: ArcSwap::new(Arc::new(vec![])),
1740            rate_limiter: Some(build_rate_limiter(&RateLimitConfig {
1741                max_attempts_per_minute: 1,
1742                pre_auth_max_per_minute: None,
1743                ..Default::default()
1744            })),
1745            pre_auth_limiter: None,
1746            #[cfg(feature = "oauth")]
1747            jwks_cache: None,
1748            seen_identities: Mutex::new(HashSet::new()),
1749            counters: AuthCounters::default(),
1750        });
1751        let app = auth_router(state);
1752
1753        // First request: UNAUTHORIZED (no credentials, but not rate limited)
1754        let req = Request::builder()
1755            .method(axum::http::Method::POST)
1756            .uri("/mcp")
1757            .body(Body::empty())
1758            .unwrap();
1759        let resp = app.clone().oneshot(req).await.unwrap();
1760        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
1761
1762        // Second request from same "IP" (no ConnectInfo in test, so peer_addr is None
1763        // and rate limiter won't fire). That's expected -- rate limiting requires
1764        // ConnectInfo which isn't available in unit tests without a real server.
1765        // This test verifies the middleware wiring doesn't panic.
1766    }
1767
1768    /// Verify that rate limit semantics: only failed auth attempts consume budget.
1769    ///
1770    /// This is a unit test of the limiter behavior. The middleware integration
1771    /// is that on auth failure, `check_key` is called; on auth success, it is NOT.
1772    /// Full e2e tests verify the middleware routing but require `ConnectInfo`.
1773    #[test]
1774    fn rate_limit_semantics_failed_only() {
1775        let config = RateLimitConfig {
1776            max_attempts_per_minute: 3,
1777            pre_auth_max_per_minute: None,
1778            ..Default::default()
1779        };
1780        let limiter = build_rate_limiter(&config);
1781        let ip: IpAddr = "192.168.1.100".parse().unwrap();
1782
1783        // Simulate: 3 failed attempts should exhaust quota.
1784        assert!(
1785            limiter.check_key(&ip).is_ok(),
1786            "failure 1 should be allowed"
1787        );
1788        assert!(
1789            limiter.check_key(&ip).is_ok(),
1790            "failure 2 should be allowed"
1791        );
1792        assert!(
1793            limiter.check_key(&ip).is_ok(),
1794            "failure 3 should be allowed"
1795        );
1796        assert!(
1797            limiter.check_key(&ip).is_err(),
1798            "failure 4 should be blocked"
1799        );
1800
1801        // In the actual middleware flow:
1802        // - Successful auth: verify_bearer_token returns Some, we return early
1803        //   WITHOUT calling check_key, so no budget consumed.
1804        // - Failed auth: verify_bearer_token returns None, we call check_key
1805        //   THEN return 401, so budget is consumed.
1806        //
1807        // This means N successful requests followed by M failed requests
1808        // will only count M toward the rate limit, not N+M.
1809    }
1810
1811    // -- pre-auth abuse gate (H-S1) --
1812
1813    /// The pre-auth gate must default to ~10x the post-failure quota so honest
1814    /// retry storms never trip it but a Argon2-spray attacker is throttled.
1815    #[test]
1816    fn pre_auth_default_multiplier_is_10x() {
1817        let config = RateLimitConfig {
1818            max_attempts_per_minute: 5,
1819            pre_auth_max_per_minute: None,
1820            ..Default::default()
1821        };
1822        let limiter = build_pre_auth_limiter(&config);
1823        let ip: IpAddr = "10.0.0.1".parse().unwrap();
1824
1825        // Quota should be 50 (5 * 10), not 5. We expect the first 50 to pass.
1826        for i in 0..50 {
1827            assert!(
1828                limiter.check_key(&ip).is_ok(),
1829                "pre-auth attempt {i} (of expected 50) should be allowed under default 10x multiplier"
1830            );
1831        }
1832        // The 51st attempt must be blocked: confirms quota is bounded, not infinite.
1833        assert!(
1834            limiter.check_key(&ip).is_err(),
1835            "pre-auth attempt 51 should be blocked (quota is 50, not unbounded)"
1836        );
1837    }
1838
1839    /// An explicit `pre_auth_max_per_minute` override must win over the
1840    /// 10x-multiplier default.
1841    #[test]
1842    fn pre_auth_explicit_override_wins() {
1843        let config = RateLimitConfig {
1844            max_attempts_per_minute: 100,     // would default to 1000 pre-auth quota
1845            pre_auth_max_per_minute: Some(2), // but operator caps at 2
1846            ..Default::default()
1847        };
1848        let limiter = build_pre_auth_limiter(&config);
1849        let ip: IpAddr = "10.0.0.2".parse().unwrap();
1850
1851        assert!(limiter.check_key(&ip).is_ok(), "attempt 1 allowed");
1852        assert!(limiter.check_key(&ip).is_ok(), "attempt 2 allowed");
1853        assert!(
1854            limiter.check_key(&ip).is_err(),
1855            "attempt 3 must be blocked (explicit override of 2 wins over 10x default of 1000)"
1856        );
1857    }
1858
1859    /// End-to-end: the pre-auth gate must reject before the bearer-verification
1860    /// path runs. We exhaust the gate's quota (Some(1)) with one bad-bearer
1861    /// request, then the second request must be rejected with 429 + the
1862    /// `pre_auth_gate` failure counter incremented (NOT
1863    /// `failure_invalid_credential`, which would prove Argon2 ran).
1864    #[tokio::test]
1865    async fn pre_auth_gate_blocks_before_argon2_verification() {
1866        let (_token, hash) = generate_api_key().unwrap();
1867        let keys = vec![ApiKeyEntry {
1868            name: "test-key".into(),
1869            hash,
1870            role: "ops".into(),
1871            expires_at: None,
1872        }];
1873        let config = RateLimitConfig {
1874            max_attempts_per_minute: 100,
1875            pre_auth_max_per_minute: Some(1),
1876            ..Default::default()
1877        };
1878        let state = Arc::new(AuthState {
1879            api_keys: ArcSwap::new(Arc::new(keys)),
1880            rate_limiter: None,
1881            pre_auth_limiter: Some(build_pre_auth_limiter(&config)),
1882            #[cfg(feature = "oauth")]
1883            jwks_cache: None,
1884            seen_identities: Mutex::new(HashSet::new()),
1885            counters: AuthCounters::default(),
1886        });
1887        let app = auth_router(Arc::clone(&state));
1888        let peer: SocketAddr = "10.0.0.10:54321".parse().unwrap();
1889
1890        // First bad-bearer request: gate has quota, bearer verification runs,
1891        // returns 401 (invalid credential).
1892        let mut req1 = Request::builder()
1893            .method(axum::http::Method::POST)
1894            .uri("/mcp")
1895            .header("authorization", "Bearer obviously-not-a-real-token")
1896            .body(Body::empty())
1897            .unwrap();
1898        req1.extensions_mut().insert(ConnectInfo(peer));
1899        let resp1 = app.clone().oneshot(req1).await.unwrap();
1900        assert_eq!(
1901            resp1.status(),
1902            StatusCode::UNAUTHORIZED,
1903            "first attempt: gate has quota, falls through to bearer auth which fails with 401"
1904        );
1905
1906        // Second bad-bearer request from same IP: gate quota exhausted, must
1907        // reject with 429 BEFORE the Argon2 verification path runs.
1908        let mut req2 = Request::builder()
1909            .method(axum::http::Method::POST)
1910            .uri("/mcp")
1911            .header("authorization", "Bearer also-not-a-real-token")
1912            .body(Body::empty())
1913            .unwrap();
1914        req2.extensions_mut().insert(ConnectInfo(peer));
1915        let resp2 = app.oneshot(req2).await.unwrap();
1916        assert_eq!(
1917            resp2.status(),
1918            StatusCode::TOO_MANY_REQUESTS,
1919            "second attempt from same IP: pre-auth gate must reject with 429"
1920        );
1921
1922        let counters = state.counters_snapshot();
1923        assert_eq!(
1924            counters.failure_pre_auth_gate, 1,
1925            "exactly one request must have been rejected by the pre-auth gate"
1926        );
1927        // Critical: Argon2 verification must NOT have run on the gated request.
1928        // The first request's 401 increments `failure_invalid_credential` to 1;
1929        // the second (gated) request must NOT increment it further.
1930        assert_eq!(
1931            counters.failure_invalid_credential, 1,
1932            "bearer verification must run exactly once (only the un-gated first request)"
1933        );
1934    }
1935
1936    /// mTLS-authenticated requests must bypass the pre-auth gate entirely.
1937    /// The TLS handshake already performed expensive crypto with a verified
1938    /// peer, so mTLS callers should never be throttled by this gate.
1939    ///
1940    /// Setup: a pre-auth gate with quota 1 (very tight). Submit two mTLS
1941    /// requests in quick succession from the same IP. Both must succeed.
1942    #[tokio::test]
1943    async fn pre_auth_gate_does_not_throttle_mtls() {
1944        let config = RateLimitConfig {
1945            max_attempts_per_minute: 100,
1946            pre_auth_max_per_minute: Some(1), // tight: would block 2nd plain request
1947            ..Default::default()
1948        };
1949        let state = Arc::new(AuthState {
1950            api_keys: ArcSwap::new(Arc::new(vec![])),
1951            rate_limiter: None,
1952            pre_auth_limiter: Some(build_pre_auth_limiter(&config)),
1953            #[cfg(feature = "oauth")]
1954            jwks_cache: None,
1955            seen_identities: Mutex::new(HashSet::new()),
1956            counters: AuthCounters::default(),
1957        });
1958        let app = auth_router(Arc::clone(&state));
1959        let peer: SocketAddr = "10.0.0.20:54321".parse().unwrap();
1960        let identity = AuthIdentity {
1961            name: "cn=test-client".into(),
1962            role: "viewer".into(),
1963            method: AuthMethod::MtlsCertificate,
1964            raw_token: None,
1965            sub: None,
1966        };
1967        let tls_info = TlsConnInfo::new(peer, Some(identity));
1968
1969        for i in 0..3 {
1970            let mut req = Request::builder()
1971                .method(axum::http::Method::POST)
1972                .uri("/mcp")
1973                .body(Body::empty())
1974                .unwrap();
1975            req.extensions_mut().insert(ConnectInfo(tls_info.clone()));
1976            let resp = app.clone().oneshot(req).await.unwrap();
1977            assert_eq!(
1978                resp.status(),
1979                StatusCode::OK,
1980                "mTLS request {i} must succeed: pre-auth gate must not apply to mTLS callers"
1981            );
1982        }
1983
1984        let counters = state.counters_snapshot();
1985        assert_eq!(
1986            counters.failure_pre_auth_gate, 0,
1987            "pre-auth gate counter must remain at zero: mTLS bypasses the gate"
1988        );
1989        assert_eq!(
1990            counters.success_mtls, 3,
1991            "all three mTLS requests must have been counted as successful"
1992        );
1993    }
1994
1995    // -------------------------------------------------------------------
1996    // RFC 7235 §2.1 case-insensitive scheme parsing for `extract_bearer`.
1997    // -------------------------------------------------------------------
1998
1999    #[test]
2000    fn extract_bearer_accepts_canonical_case() {
2001        assert_eq!(extract_bearer("Bearer abc123"), Some("abc123"));
2002    }
2003
2004    #[test]
2005    fn extract_bearer_is_case_insensitive_per_rfc7235() {
2006        // RFC 7235 §2.1: "auth-scheme is case-insensitive".
2007        // Real-world clients (curl, browsers, custom HTTP libs) emit varied
2008        // casings; rejecting any of them is a spec violation.
2009        for header in &[
2010            "bearer abc123",
2011            "BEARER abc123",
2012            "BeArEr abc123",
2013            "bEaReR abc123",
2014        ] {
2015            assert_eq!(
2016                extract_bearer(header),
2017                Some("abc123"),
2018                "header {header:?} must parse as a Bearer token (RFC 7235 §2.1)"
2019            );
2020        }
2021    }
2022
2023    #[test]
2024    fn extract_bearer_rejects_other_schemes() {
2025        assert_eq!(extract_bearer("Basic dXNlcjpwYXNz"), None);
2026        assert_eq!(extract_bearer("Digest username=\"x\""), None);
2027        assert_eq!(extract_bearer("Token abc123"), None);
2028    }
2029
2030    #[test]
2031    fn extract_bearer_rejects_malformed() {
2032        // Empty string, no separator, scheme-only, scheme + only whitespace.
2033        assert_eq!(extract_bearer(""), None);
2034        assert_eq!(extract_bearer("Bearer"), None);
2035        assert_eq!(extract_bearer("Bearer "), None);
2036        assert_eq!(extract_bearer("Bearer    "), None);
2037    }
2038
2039    #[test]
2040    fn extract_bearer_tolerates_extra_separator_whitespace() {
2041        // Some non-conformant clients emit two spaces; we should still parse.
2042        assert_eq!(extract_bearer("Bearer  abc123"), Some("abc123"));
2043        assert_eq!(extract_bearer("Bearer   abc123"), Some("abc123"));
2044    }
2045
2046    // -------------------------------------------------------------------
2047    // Debug redaction: ensure `AuthIdentity` and `ApiKeyEntry` never leak
2048    // secret material via `format!("{:?}", …)` or `tracing::debug!(?…)`.
2049    // -------------------------------------------------------------------
2050
2051    #[test]
2052    fn auth_identity_debug_redacts_raw_token() {
2053        let id = AuthIdentity {
2054            name: "alice".into(),
2055            role: "admin".into(),
2056            method: AuthMethod::OAuthJwt,
2057            raw_token: Some(SecretString::from("super-secret-jwt-payload-xyz")),
2058            sub: Some("keycloak-uuid-2f3c8b".into()),
2059        };
2060        let dbg = format!("{id:?}");
2061
2062        // Plaintext fields must be visible (they are not secrets).
2063        assert!(dbg.contains("alice"), "name should be visible: {dbg}");
2064        assert!(dbg.contains("admin"), "role should be visible: {dbg}");
2065        assert!(dbg.contains("OAuthJwt"), "method should be visible: {dbg}");
2066
2067        // Secret fields must NOT leak.
2068        assert!(
2069            !dbg.contains("super-secret-jwt-payload-xyz"),
2070            "raw_token must be redacted in Debug output: {dbg}"
2071        );
2072        assert!(
2073            !dbg.contains("keycloak-uuid-2f3c8b"),
2074            "sub must be redacted in Debug output: {dbg}"
2075        );
2076        assert!(
2077            dbg.contains("<redacted>"),
2078            "redaction marker missing: {dbg}"
2079        );
2080    }
2081
2082    #[test]
2083    fn auth_identity_debug_marks_absent_secrets() {
2084        // For non-OAuth identities (mTLS / API key) the secret fields are
2085        // None; redacted Debug output should distinguish that from "present".
2086        let id = AuthIdentity {
2087            name: "viewer-key".into(),
2088            role: "viewer".into(),
2089            method: AuthMethod::BearerToken,
2090            raw_token: None,
2091            sub: None,
2092        };
2093        let dbg = format!("{id:?}");
2094        assert!(
2095            dbg.contains("<none>"),
2096            "absent secrets should be marked: {dbg}"
2097        );
2098        assert!(
2099            !dbg.contains("<redacted>"),
2100            "no <redacted> marker when secrets are absent: {dbg}"
2101        );
2102    }
2103
2104    #[test]
2105    fn api_key_entry_debug_redacts_hash() {
2106        let entry = ApiKeyEntry {
2107            name: "viewer-key".into(),
2108            // Realistic Argon2id PHC string (must NOT leak).
2109            hash: "$argon2id$v=19$m=19456,t=2,p=1$c2FsdHNhbHQ$h4sh3dPa55w0rd".into(),
2110            role: "viewer".into(),
2111            expires_at: Some(RfcTimestamp::parse("2030-01-01T00:00:00Z").unwrap()),
2112        };
2113        let dbg = format!("{entry:?}");
2114
2115        // Non-secret fields visible.
2116        assert!(dbg.contains("viewer-key"));
2117        assert!(dbg.contains("viewer"));
2118        assert!(dbg.contains("2030-01-01T00:00:00+00:00"));
2119
2120        // Hash material must NOT leak.
2121        assert!(
2122            !dbg.contains("$argon2id$"),
2123            "argon2 hash leaked into Debug output: {dbg}"
2124        );
2125        assert!(
2126            !dbg.contains("h4sh3dPa55w0rd"),
2127            "hash digest leaked into Debug output: {dbg}"
2128        );
2129        assert!(
2130            dbg.contains("<redacted>"),
2131            "redaction marker missing: {dbg}"
2132        );
2133    }
2134
2135    // -- AuthFailureClass exact-string contract tests --
2136    //
2137    // These tests pin the exact wire strings emitted for each failure
2138    // class. They exist to kill mutation-test mutants that replace the
2139    // match-arm string literals (e.g. with `""` or with the value from
2140    // another arm). Operators and dashboards rely on these literals
2141    // for metric labels and audit-log filters; any change is a
2142    // breaking observability change and must be reflected in
2143    // CHANGELOG.md.
2144
2145    #[test]
2146    fn auth_failure_class_as_str_exact_strings() {
2147        assert_eq!(
2148            AuthFailureClass::MissingCredential.as_str(),
2149            "missing_credential"
2150        );
2151        assert_eq!(
2152            AuthFailureClass::InvalidCredential.as_str(),
2153            "invalid_credential"
2154        );
2155        assert_eq!(
2156            AuthFailureClass::ExpiredCredential.as_str(),
2157            "expired_credential"
2158        );
2159        assert_eq!(AuthFailureClass::RateLimited.as_str(), "rate_limited");
2160        assert_eq!(AuthFailureClass::PreAuthGate.as_str(), "pre_auth_gate");
2161    }
2162
2163    #[test]
2164    fn auth_failure_class_response_body_exact_strings() {
2165        assert_eq!(
2166            AuthFailureClass::MissingCredential.response_body(),
2167            "unauthorized: missing credential"
2168        );
2169        assert_eq!(
2170            AuthFailureClass::InvalidCredential.response_body(),
2171            "unauthorized: invalid credential"
2172        );
2173        assert_eq!(
2174            AuthFailureClass::ExpiredCredential.response_body(),
2175            "unauthorized: expired credential"
2176        );
2177        assert_eq!(
2178            AuthFailureClass::RateLimited.response_body(),
2179            "rate limited"
2180        );
2181        assert_eq!(
2182            AuthFailureClass::PreAuthGate.response_body(),
2183            "rate limited (pre-auth)"
2184        );
2185    }
2186
2187    #[test]
2188    fn auth_failure_class_bearer_error_exact_strings() {
2189        assert_eq!(
2190            AuthFailureClass::MissingCredential.bearer_error(),
2191            (
2192                "invalid_request",
2193                "missing bearer token or mTLS client certificate"
2194            )
2195        );
2196        assert_eq!(
2197            AuthFailureClass::InvalidCredential.bearer_error(),
2198            ("invalid_token", "token is invalid")
2199        );
2200        assert_eq!(
2201            AuthFailureClass::ExpiredCredential.bearer_error(),
2202            ("invalid_token", "token is expired")
2203        );
2204        assert_eq!(
2205            AuthFailureClass::RateLimited.bearer_error(),
2206            ("invalid_request", "too many failed authentication attempts")
2207        );
2208        assert_eq!(
2209            AuthFailureClass::PreAuthGate.bearer_error(),
2210            (
2211                "invalid_request",
2212                "too many unauthenticated requests from this source"
2213            )
2214        );
2215    }
2216
2217    // -- AuthConfig::summary boolean-flag contract tests --
2218    //
2219    // These tests pin the boolean flags emitted by `AuthConfig::summary`
2220    // so that mutations like deleting `!` (which would invert the
2221    // semantics of `bearer`) or replacing `is_some()` with `is_none()`
2222    // are caught immediately. The summary is consumed by `/admin/*`
2223    // diagnostics so any inversion is an operator-visible regression.
2224
2225    #[test]
2226    fn auth_config_summary_bearer_true_when_keys_present() {
2227        let (_token, hash) = generate_api_key().unwrap();
2228        let cfg = AuthConfig::with_keys(vec![ApiKeyEntry::new("k", hash, "viewer")]);
2229        let s = cfg.summary();
2230        assert!(s.enabled, "summary.enabled must reflect AuthConfig.enabled");
2231        assert!(
2232            s.bearer,
2233            "summary.bearer must be true when api_keys is non-empty (kills `!` deletion at L615)"
2234        );
2235        assert!(!s.mtls, "summary.mtls must be false when mtls is None");
2236        assert!(!s.oauth, "summary.oauth must be false when oauth is None");
2237        assert_eq!(s.api_keys.len(), 1);
2238        assert_eq!(s.api_keys[0].name, "k");
2239        assert_eq!(s.api_keys[0].role, "viewer");
2240    }
2241
2242    #[test]
2243    fn auth_config_summary_bearer_false_when_no_keys() {
2244        let cfg = AuthConfig::with_keys(vec![]);
2245        let s = cfg.summary();
2246        assert!(
2247            !s.bearer,
2248            "summary.bearer must be false when api_keys is empty (kills `!` deletion at L615)"
2249        );
2250        assert!(s.api_keys.is_empty());
2251    }
2252}