Skip to main content

rmcp_server_kit/
auth.rs

1//! Authentication middleware for MCP servers.
2//!
3//! Supports multiple authentication methods tried in priority order:
4//! 1. mTLS client certificate (if configured and peer cert present)
5//! 2. Bearer token (API key) with Argon2id hash verification
6//!
7//! Includes per-source-IP rate limiting on authentication attempts.
8
9use std::{
10    collections::HashSet,
11    net::{IpAddr, SocketAddr},
12    num::NonZeroU32,
13    path::PathBuf,
14    sync::{
15        Arc, Mutex,
16        atomic::{AtomicU64, Ordering},
17    },
18    time::Duration,
19};
20
21use arc_swap::ArcSwap;
22use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier, password_hash::SaltString};
23use axum::{
24    body::Body,
25    extract::ConnectInfo,
26    http::{Request, header},
27    middleware::Next,
28    response::{IntoResponse, Response},
29};
30use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
31use secrecy::SecretString;
32use serde::Deserialize;
33use x509_parser::prelude::*;
34
35use crate::{bounded_limiter::BoundedKeyedLimiter, error::McpxError};
36
/// Identity of an authenticated caller.
///
/// One value is produced per successfully authenticated request (or, for
/// mTLS, per connection); [`AuthIdentity::method`] records which mechanism
/// produced it.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct AuthIdentity {
    /// Human-readable identity name (e.g. API key label or cert CN).
    pub name: String,
    /// RBAC role associated with this identity.
    pub role: String,
    /// Which authentication mechanism produced this identity.
    pub method: AuthMethod,
    /// Raw bearer token from the `Authorization` header, wrapped in
    /// [`SecretString`] so it is never accidentally logged or serialized.
    /// Present for OAuth JWT; `None` for mTLS and API-key auth (both
    /// constructors in this module set it to `None`).
    /// Tool handlers use this for downstream token passthrough via
    /// [`crate::rbac::current_token`].
    pub raw_token: Option<SecretString>,
    /// JWT `sub` claim (stable user identifier, e.g. Keycloak UUID).
    /// Used for token store keying. `None` for non-JWT auth.
    pub sub: Option<String>,
}
57
/// How the caller authenticated.
///
/// Stored on [`AuthIdentity::method`] and used to pick which success
/// counter is incremented.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum AuthMethod {
    /// Bearer API key (Argon2id-hashed, configured statically).
    BearerToken,
    /// Mutual TLS client certificate.
    MtlsCertificate,
    /// OAuth 2.1 JWT bearer token (validated via JWKS).
    OAuthJwt,
}
69
/// Why an authentication attempt was rejected.
///
/// Each class maps to a stable metric/log label, an OAuth-style bearer
/// error pair, and a plain-text response body.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AuthFailureClass {
    /// The request carried neither a bearer token nor a client certificate.
    MissingCredential,
    /// A credential was presented but did not validate.
    InvalidCredential,
    /// A credential was presented but is past its expiry.
    #[cfg_attr(not(feature = "oauth"), allow(dead_code))]
    ExpiredCredential,
    /// Source IP exceeded the post-failure backoff limit.
    RateLimited,
    /// Source IP exceeded the pre-auth abuse gate (rejected before any
    /// password-hash work — see [`AuthState::pre_auth_limiter`]).
    PreAuthGate,
}

impl AuthFailureClass {
    /// Single source of truth for every string attached to a failure class.
    ///
    /// Tuple layout: `(label, (bearer_error_code, bearer_error_description),
    /// response_body)`. Keeping the three views in one `match` means a new
    /// variant cannot be added without supplying all of its strings.
    const fn table(self) -> (&'static str, (&'static str, &'static str), &'static str) {
        match self {
            Self::MissingCredential => (
                "missing_credential",
                (
                    "invalid_request",
                    "missing bearer token or mTLS client certificate",
                ),
                "unauthorized: missing credential",
            ),
            Self::InvalidCredential => (
                "invalid_credential",
                ("invalid_token", "token is invalid"),
                "unauthorized: invalid credential",
            ),
            Self::ExpiredCredential => (
                "expired_credential",
                ("invalid_token", "token is expired"),
                "unauthorized: expired credential",
            ),
            Self::RateLimited => (
                "rate_limited",
                ("invalid_request", "too many failed authentication attempts"),
                "rate limited",
            ),
            Self::PreAuthGate => (
                "pre_auth_gate",
                (
                    "invalid_request",
                    "too many unauthenticated requests from this source",
                ),
                "rate limited (pre-auth)",
            ),
        }
    }

    /// Stable snake_case label for this failure class.
    fn as_str(self) -> &'static str {
        self.table().0
    }

    /// `(error, error_description)` pair describing this failure in
    /// bearer-token error terms.
    fn bearer_error(self) -> (&'static str, &'static str) {
        self.table().1
    }

    /// Plain-text body returned to the client for this failure class.
    fn response_body(self) -> &'static str {
        self.table().2
    }
}
120
/// Snapshot of authentication success/failure counters.
///
/// A plain-data copy of the internal atomic counters, taken field by field
/// with relaxed loads; the fields are therefore individually accurate but
/// not captured at a single instant relative to one another.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
#[non_exhaustive]
pub struct AuthCountersSnapshot {
    /// Successful mTLS authentications.
    pub success_mtls: u64,
    /// Successful bearer-token authentications.
    pub success_bearer: u64,
    /// Successful OAuth JWT authentications.
    pub success_oauth_jwt: u64,
    /// Failures because no credential was presented.
    pub failure_missing_credential: u64,
    /// Failures because the credential was malformed or wrong.
    pub failure_invalid_credential: u64,
    /// Failures because the credential had expired.
    pub failure_expired_credential: u64,
    /// Failures because the source IP was rate-limited (post-failure backoff).
    pub failure_rate_limited: u64,
    /// Failures because the source IP exceeded the pre-auth abuse gate.
    /// These never reach the password-hash verification path.
    pub failure_pre_auth_gate: u64,
}
143
144/// Internal atomic counters backing [`AuthCountersSnapshot`].
145#[derive(Debug, Default)]
146pub(crate) struct AuthCounters {
147    success_mtls: AtomicU64,
148    success_bearer: AtomicU64,
149    success_oauth_jwt: AtomicU64,
150    failure_missing_credential: AtomicU64,
151    failure_invalid_credential: AtomicU64,
152    failure_expired_credential: AtomicU64,
153    failure_rate_limited: AtomicU64,
154    failure_pre_auth_gate: AtomicU64,
155}
156
157impl AuthCounters {
158    fn record_success(&self, method: AuthMethod) {
159        match method {
160            AuthMethod::MtlsCertificate => {
161                self.success_mtls.fetch_add(1, Ordering::Relaxed);
162            }
163            AuthMethod::BearerToken => {
164                self.success_bearer.fetch_add(1, Ordering::Relaxed);
165            }
166            AuthMethod::OAuthJwt => {
167                self.success_oauth_jwt.fetch_add(1, Ordering::Relaxed);
168            }
169        }
170    }
171
172    fn record_failure(&self, class: AuthFailureClass) {
173        match class {
174            AuthFailureClass::MissingCredential => {
175                self.failure_missing_credential
176                    .fetch_add(1, Ordering::Relaxed);
177            }
178            AuthFailureClass::InvalidCredential => {
179                self.failure_invalid_credential
180                    .fetch_add(1, Ordering::Relaxed);
181            }
182            AuthFailureClass::ExpiredCredential => {
183                self.failure_expired_credential
184                    .fetch_add(1, Ordering::Relaxed);
185            }
186            AuthFailureClass::RateLimited => {
187                self.failure_rate_limited.fetch_add(1, Ordering::Relaxed);
188            }
189            AuthFailureClass::PreAuthGate => {
190                self.failure_pre_auth_gate.fetch_add(1, Ordering::Relaxed);
191            }
192        }
193    }
194
195    fn snapshot(&self) -> AuthCountersSnapshot {
196        AuthCountersSnapshot {
197            success_mtls: self.success_mtls.load(Ordering::Relaxed),
198            success_bearer: self.success_bearer.load(Ordering::Relaxed),
199            success_oauth_jwt: self.success_oauth_jwt.load(Ordering::Relaxed),
200            failure_missing_credential: self.failure_missing_credential.load(Ordering::Relaxed),
201            failure_invalid_credential: self.failure_invalid_credential.load(Ordering::Relaxed),
202            failure_expired_credential: self.failure_expired_credential.load(Ordering::Relaxed),
203            failure_rate_limited: self.failure_rate_limited.load(Ordering::Relaxed),
204            failure_pre_auth_gate: self.failure_pre_auth_gate.load(Ordering::Relaxed),
205        }
206    }
207}
208
209/// A single API key entry (stored as Argon2id hash in config).
210#[derive(Debug, Clone, Deserialize)]
211#[non_exhaustive]
212pub struct ApiKeyEntry {
213    /// Human-readable key label (used in logs and audit records).
214    pub name: String,
215    /// Argon2id hash of the token (PHC string format).
216    pub hash: String,
217    /// RBAC role granted when this key authenticates successfully.
218    pub role: String,
219    /// Optional expiry in RFC 3339 format.
220    pub expires_at: Option<String>,
221}
222
223impl ApiKeyEntry {
224    /// Create a new API key entry (no expiry).
225    #[must_use]
226    pub fn new(name: impl Into<String>, hash: impl Into<String>, role: impl Into<String>) -> Self {
227        Self {
228            name: name.into(),
229            hash: hash.into(),
230            role: role.into(),
231            expires_at: None,
232        }
233    }
234
235    /// Set an RFC 3339 expiry on this key.
236    #[must_use]
237    pub fn with_expiry(mut self, expires_at: impl Into<String>) -> Self {
238        self.expires_at = Some(expires_at.into());
239        self
240    }
241}
242
/// mTLS client certificate authentication configuration.
///
/// When deserialized, only `ca_cert_path` is mandatory: every other field
/// carries a serde default (noted per field below).
#[derive(Debug, Clone, Deserialize)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "mTLS CRL behavior is intentionally configured as independent booleans"
)]
#[non_exhaustive]
pub struct MtlsConfig {
    /// Path to CA certificate(s) for verifying client certs (PEM format).
    pub ca_cert_path: PathBuf,
    /// If true, clients MUST present a valid certificate.
    /// If false, client certs are optional (verified if presented).
    /// Default: `false`.
    #[serde(default)]
    pub required: bool,
    /// Default RBAC role for mTLS-authenticated clients.
    /// The client cert CN becomes the identity name.
    /// Default: `"viewer"`.
    #[serde(default = "default_mtls_role")]
    pub default_role: String,
    /// Enable CRL-based certificate revocation checks using CDP URLs from the
    /// configured CA chain and connecting client certificates.
    /// Default: `true`.
    #[serde(default = "default_true")]
    pub crl_enabled: bool,
    /// Optional fixed refresh interval for known CRLs. When omitted, refresh
    /// cadence is derived from `nextUpdate` and clamped internally.
    /// Parsed from a human-readable duration string.
    #[serde(default, with = "humantime_serde::option")]
    pub crl_refresh_interval: Option<Duration>,
    /// Timeout for individual CRL fetches. Default: 30 seconds.
    #[serde(default = "default_crl_fetch_timeout", with = "humantime_serde")]
    pub crl_fetch_timeout: Duration,
    /// Grace window during which stale CRLs may still be used when refresh
    /// attempts fail. Default: 24 hours.
    #[serde(default = "default_crl_stale_grace", with = "humantime_serde")]
    pub crl_stale_grace: Duration,
    /// When true, missing or unavailable CRLs cause revocation checks to fail
    /// closed. Default: `false`.
    #[serde(default)]
    pub crl_deny_on_unavailable: bool,
    /// When true, apply revocation checks only to the end-entity certificate.
    /// Default: `false`.
    #[serde(default)]
    pub crl_end_entity_only: bool,
    /// Allow HTTP CRL distribution-point URLs in addition to HTTPS.
    ///
    /// Defaults to `true` because RFC 5280 §4.2.1.13 designates HTTP (and
    /// LDAP) as the canonical transport for CRL distribution points.
    /// SSRF defense for HTTP CDPs is provided by the IP-allowlist guard
    /// (private/loopback/link-local/multicast/cloud-metadata addresses are
    /// always rejected), redirect=none, body-size cap, and per-host
    /// concurrency limit -- not by forcing HTTPS.
    #[serde(default = "default_true")]
    pub crl_allow_http: bool,
    /// Enforce CRL expiration during certificate validation.
    /// Default: `true`.
    #[serde(default = "default_true")]
    pub crl_enforce_expiration: bool,
    /// Maximum concurrent CRL fetches across all hosts. Defense in depth
    /// against SSRF amplification: even if many CDPs are discovered, no
    /// more than this many fetches run in parallel. Per-host concurrency
    /// is independently capped at 1 regardless of this value.
    /// Default: `4`.
    #[serde(default = "default_crl_max_concurrent_fetches")]
    pub crl_max_concurrent_fetches: usize,
    /// Hard cap on each CRL response body in bytes. Fetches exceeding this
    /// are aborted mid-stream to bound memory and prevent gzip-bomb-style
    /// amplification. Default: 5 MiB (`5 * 1024 * 1024`).
    #[serde(default = "default_crl_max_response_bytes")]
    pub crl_max_response_bytes: u64,
    /// Global CDP discovery rate limit, in URLs per minute. Throttles
    /// how many *new* CDP URLs the verifier may admit into the fetch
    /// pipeline across the whole process, bounding asymmetric `DoS`
    /// amplification when attacker-controlled certificates carry large
    /// CDP lists. The limit is global (not per-source-IP) in this
    /// release; per-IP scoping is deferred to a future version because
    /// it requires plumbing the peer `SocketAddr` through the verifier
    /// hook. URLs that lose the rate-limiter race are *not* marked as
    /// seen, so subsequent handshakes observing the same URL can
    /// retry admission.
    /// Default: `60`.
    #[serde(default = "default_crl_discovery_rate_per_min")]
    pub crl_discovery_rate_per_min: u32,
    /// Maximum number of distinct hosts that may hold a CRL fetch
    /// semaphore at any time. Requests that would grow the map beyond
    /// this cap return [`McpxError::Config`] containing the literal
    /// substring `"crl_host_semaphore_cap_exceeded"`. Bounds memory
    /// growth from attacker-controlled CDP URLs pointing at unique
    /// hostnames. Default: 1024.
    #[serde(default = "default_crl_max_host_semaphores")]
    pub crl_max_host_semaphores: usize,
    /// Maximum number of distinct URLs tracked in the "seen" set.
    /// Beyond this, additional discovered URLs are silently dropped
    /// with a rate-limited warn! log; no error surfaces. Default: 4096.
    #[serde(default = "default_crl_max_seen_urls")]
    pub crl_max_seen_urls: usize,
    /// Maximum number of cached CRL entries. Beyond this, new
    /// successful fetches are silently dropped with a rate-limited
    /// warn! log (newest-rejected, not LRU-evicted). Default: 1024.
    #[serde(default = "default_crl_max_cache_entries")]
    pub crl_max_cache_entries: usize,
}
340
/// Serde default for `MtlsConfig::default_role`: the `"viewer"` role.
fn default_mtls_role() -> String {
    String::from("viewer")
}

/// Serde default for boolean options that are on unless disabled.
const fn default_true() -> bool {
    true
}

/// Serde default for `MtlsConfig::crl_fetch_timeout`: 30 seconds.
const fn default_crl_fetch_timeout() -> Duration {
    Duration::from_secs(30)
}

/// Serde default for `MtlsConfig::crl_stale_grace`: 24 hours.
const fn default_crl_stale_grace() -> Duration {
    const SECS_PER_HOUR: u64 = 60 * 60;
    Duration::from_secs(24 * SECS_PER_HOUR)
}

/// Serde default for `MtlsConfig::crl_max_concurrent_fetches`.
const fn default_crl_max_concurrent_fetches() -> usize {
    4
}

/// Serde default for `MtlsConfig::crl_max_response_bytes`: 5 MiB.
const fn default_crl_max_response_bytes() -> u64 {
    const MIB: u64 = 1024 * 1024;
    5 * MIB
}

/// Serde default for `MtlsConfig::crl_discovery_rate_per_min`.
const fn default_crl_discovery_rate_per_min() -> u32 {
    60
}

/// Serde default for `MtlsConfig::crl_max_host_semaphores`.
const fn default_crl_max_host_semaphores() -> usize {
    1_024
}

/// Serde default for `MtlsConfig::crl_max_seen_urls`.
const fn default_crl_max_seen_urls() -> usize {
    4_096
}

/// Serde default for `MtlsConfig::crl_max_cache_entries`.
const fn default_crl_max_cache_entries() -> usize {
    1_024
}
380
/// Rate limiting configuration for authentication attempts.
///
/// rmcp-server-kit uses two independent per-IP token-bucket limiters for auth:
///
/// 1. **Pre-auth abuse gate** ([`Self::pre_auth_max_per_minute`]): consulted
///    *before* any password-hash work. Throttles unauthenticated traffic from
///    a single source IP so an attacker cannot pin the CPU on Argon2id by
///    spraying invalid bearer tokens. Sized generously (default = 10× the
///    post-failure quota) so legitimate clients are unaffected. mTLS-
///    authenticated connections bypass this gate entirely (the TLS handshake
///    already performed expensive crypto with a verified peer).
/// 2. **Post-failure backoff** ([`Self::max_attempts_per_minute`]): consulted
///    *after* an authentication attempt fails. Provides explicit backpressure
///    on bad credentials.
///
/// Both limiters share [`Self::max_tracked_keys`] and
/// [`Self::idle_eviction`] for their memory bounds.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct RateLimitConfig {
    /// Maximum failed authentication attempts per source IP per minute.
    /// Successful authentications do not consume this budget.
    /// Default: `30`. A value of `0` cannot form a quota, so the limiter is
    /// built with the default rate of 30 instead.
    #[serde(default = "default_max_attempts")]
    pub max_attempts_per_minute: u32,
    /// Maximum *unauthenticated* requests per source IP per minute admitted
    /// to the password-hash verification path. When `None`, defaults to
    /// `max_attempts_per_minute * 10` at limiter-construction time. If the
    /// resolved value is `0`, the limiter is built with a rate of 300.
    ///
    /// Set higher than [`Self::max_attempts_per_minute`] so honest clients
    /// retrying with the wrong key never trip this gate; its purpose is only
    /// to bound CPU usage under spray attacks.
    #[serde(default)]
    pub pre_auth_max_per_minute: Option<u32>,
    /// Hard cap on the number of distinct source IPs tracked per limiter.
    /// When reached, idle entries are pruned first; if still full, the
    /// oldest (LRU) entry is evicted to make room for the new one. This
    /// bounds memory under IP-spray attacks. Default: `10_000`.
    #[serde(default = "default_max_tracked_keys")]
    pub max_tracked_keys: usize,
    /// Per-IP entries idle for longer than this are eligible for
    /// opportunistic pruning. Default: 15 minutes.
    #[serde(default = "default_idle_eviction", with = "humantime_serde")]
    pub idle_eviction: Duration,
}
422
423impl Default for RateLimitConfig {
424    fn default() -> Self {
425        Self {
426            max_attempts_per_minute: default_max_attempts(),
427            pre_auth_max_per_minute: None,
428            max_tracked_keys: default_max_tracked_keys(),
429            idle_eviction: default_idle_eviction(),
430        }
431    }
432}
433
434impl RateLimitConfig {
435    /// Create a rate limit config with the given max failed attempts per minute.
436    /// Pre-auth gate defaults to `10x` this value at limiter-construction time.
437    /// Memory-bound defaults are `10_000` tracked keys with 15-minute idle eviction.
438    #[must_use]
439    pub fn new(max_attempts_per_minute: u32) -> Self {
440        Self {
441            max_attempts_per_minute,
442            ..Self::default()
443        }
444    }
445
446    /// Override the pre-auth abuse-gate quota (per source IP per minute).
447    /// When unset, defaults to `max_attempts_per_minute * 10`.
448    #[must_use]
449    pub fn with_pre_auth_max_per_minute(mut self, quota: u32) -> Self {
450        self.pre_auth_max_per_minute = Some(quota);
451        self
452    }
453
454    /// Override the per-limiter cap on tracked source-IP keys (default `10_000`).
455    #[must_use]
456    pub fn with_max_tracked_keys(mut self, max: usize) -> Self {
457        self.max_tracked_keys = max;
458        self
459    }
460
461    /// Override the idle-eviction window (default 15 minutes).
462    #[must_use]
463    pub fn with_idle_eviction(mut self, idle: Duration) -> Self {
464        self.idle_eviction = idle;
465        self
466    }
467}
468
/// Serde default for `RateLimitConfig::max_attempts_per_minute`.
fn default_max_attempts() -> u32 {
    30
}

/// Serde default for `RateLimitConfig::max_tracked_keys`.
fn default_max_tracked_keys() -> usize {
    10_000
}

/// Serde default for `RateLimitConfig::idle_eviction`: 15 minutes.
fn default_idle_eviction() -> Duration {
    const SECS_PER_MIN: u64 = 60;
    Duration::from_secs(15 * SECS_PER_MIN)
}
480
/// Authentication configuration.
///
/// The derived `Default` yields a disabled config: `enabled` is `false`,
/// there are no API keys, and every optional mechanism is `None`.
#[derive(Debug, Clone, Default, Deserialize)]
#[non_exhaustive]
pub struct AuthConfig {
    /// Master switch - when false, all requests are allowed through.
    #[serde(default)]
    pub enabled: bool,
    /// Bearer token API keys.
    #[serde(default)]
    pub api_keys: Vec<ApiKeyEntry>,
    /// mTLS client certificate authentication.
    pub mtls: Option<MtlsConfig>,
    /// Rate limiting for auth attempts.
    pub rate_limit: Option<RateLimitConfig>,
    /// OAuth 2.1 JWT bearer token authentication.
    /// Only present when the crate is built with the `oauth` feature.
    #[cfg(feature = "oauth")]
    pub oauth: Option<crate::oauth::OAuthConfig>,
}
499
500impl AuthConfig {
501    /// Create an enabled auth config with the given API keys.
502    #[must_use]
503    pub fn with_keys(keys: Vec<ApiKeyEntry>) -> Self {
504        Self {
505            enabled: true,
506            api_keys: keys,
507            mtls: None,
508            rate_limit: None,
509            #[cfg(feature = "oauth")]
510            oauth: None,
511        }
512    }
513
514    /// Set rate limiting on this auth config.
515    #[must_use]
516    pub fn with_rate_limit(mut self, rate_limit: RateLimitConfig) -> Self {
517        self.rate_limit = Some(rate_limit);
518        self
519    }
520}
521
/// Summary of a single API key suitable for admin endpoints.
///
/// Intentionally omits the Argon2id hash - only metadata is exposed.
/// Each field is a clone of the corresponding [`ApiKeyEntry`] field.
#[derive(Debug, Clone, serde::Serialize)]
#[non_exhaustive]
pub struct ApiKeySummary {
    /// Human-readable key label.
    pub name: String,
    /// RBAC role granted when this key authenticates.
    pub role: String,
    /// Optional RFC 3339 expiry timestamp.
    pub expires_at: Option<String>,
}
535
/// Snapshot of the enabled authentication methods for admin endpoints.
///
/// Built by [`AuthConfig::summary`]; contains no secret material.
#[derive(Debug, Clone, serde::Serialize)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "this is a flat summary of independent auth-method booleans"
)]
#[non_exhaustive]
pub struct AuthConfigSummary {
    /// Master enabled flag from config.
    pub enabled: bool,
    /// Whether API-key bearer auth is configured (at least one key exists).
    pub bearer: bool,
    /// Whether mTLS client auth is configured.
    pub mtls: bool,
    /// Whether OAuth JWT validation is configured. Always `false` when the
    /// crate is built without the `oauth` feature.
    pub oauth: bool,
    /// Current API-key list (no hashes).
    pub api_keys: Vec<ApiKeySummary>,
}
555
556impl AuthConfig {
557    /// Produce a hash-free summary of the auth config for admin endpoints.
558    #[must_use]
559    pub fn summary(&self) -> AuthConfigSummary {
560        AuthConfigSummary {
561            enabled: self.enabled,
562            bearer: !self.api_keys.is_empty(),
563            mtls: self.mtls.is_some(),
564            #[cfg(feature = "oauth")]
565            oauth: self.oauth.is_some(),
566            #[cfg(not(feature = "oauth"))]
567            oauth: false,
568            api_keys: self
569                .api_keys
570                .iter()
571                .map(|k| ApiKeySummary {
572                    name: k.name.clone(),
573                    role: k.role.clone(),
574                    expires_at: k.expires_at.clone(),
575                })
576                .collect(),
577        }
578    }
579}
580
/// Keyed rate limiter type (per source IP). Memory-bounded by
/// [`RateLimitConfig::max_tracked_keys`] to defend against IP-spray `DoS`.
///
/// Used for both the post-failure limiter and the pre-auth abuse gate.
pub(crate) type KeyedLimiter = BoundedKeyedLimiter<IpAddr>;
584
585/// Connection info for TLS connections, carrying the peer socket address
586/// and (when mTLS is configured) the verified client identity extracted
587/// from the peer certificate during the TLS handshake.
588///
589/// Defined as a local type so we can implement axum's `Connected` trait
590/// for our custom `TlsListener` without orphan rule issues. The `identity`
591/// field travels with the connection itself (via the wrapping IO type),
592/// so there is no shared map to race against, no port-reuse aliasing, and
593/// no eviction policy to maintain.
594#[derive(Clone, Debug)]
595#[non_exhaustive]
596pub(crate) struct TlsConnInfo {
597    /// Remote peer socket address.
598    pub addr: SocketAddr,
599    /// Verified mTLS client identity, if a client certificate was presented
600    /// and successfully extracted during the TLS handshake.
601    pub identity: Option<AuthIdentity>,
602}
603
604impl TlsConnInfo {
605    /// Construct a new [`TlsConnInfo`].
606    #[must_use]
607    pub(crate) const fn new(addr: SocketAddr, identity: Option<AuthIdentity>) -> Self {
608        Self { addr, identity }
609    }
610}
611
/// Shared state for the auth middleware.
///
/// `api_keys` uses [`ArcSwap`] so the SIGHUP handler can atomically
/// swap in a new key list without blocking in-flight requests.
#[allow(
    missing_debug_implementations,
    reason = "contains governor RateLimiter and JwksCache without Debug impls"
)]
#[non_exhaustive]
pub(crate) struct AuthState {
    /// Active set of API keys (hot-swappable).
    pub api_keys: ArcSwap<Vec<ApiKeyEntry>>,
    /// Optional per-IP post-failure rate limiter (consulted *after* auth fails).
    pub rate_limiter: Option<Arc<KeyedLimiter>>,
    /// Optional per-IP pre-auth abuse gate (consulted *before* password-hash work).
    /// mTLS-authenticated connections bypass this gate.
    pub pre_auth_limiter: Option<Arc<KeyedLimiter>>,
    /// Optional JWKS cache for OAuth JWT validation.
    /// Only present when the crate is built with the `oauth` feature.
    #[cfg(feature = "oauth")]
    pub jwks_cache: Option<Arc<crate::oauth::JwksCache>>,
    /// Tracks identity names that have already been logged at INFO level.
    /// Subsequent auths for the same identity are logged at DEBUG.
    /// Entries are only ever added to this set by the code in this struct's
    /// `impl`; nothing visible here removes them.
    pub seen_identities: Mutex<HashSet<String>>,
    /// Lightweight in-memory auth success/failure counters for diagnostics.
    pub counters: AuthCounters,
}
638
639impl AuthState {
640    /// Atomically replace the API key list (lock-free, wait-free).
641    ///
642    /// New requests immediately see the updated keys.
643    /// In-flight requests that already loaded the old list finish
644    /// using it -- no torn reads.
645    pub(crate) fn reload_keys(&self, keys: Vec<ApiKeyEntry>) {
646        let count = keys.len();
647        self.api_keys.store(Arc::new(keys));
648        tracing::info!(keys = count, "API keys reloaded");
649    }
650
651    /// Snapshot auth counters for diagnostics and tests.
652    #[must_use]
653    pub(crate) fn counters_snapshot(&self) -> AuthCountersSnapshot {
654        self.counters.snapshot()
655    }
656
657    /// Produce the admin-endpoint list of API keys (metadata only, no hashes).
658    #[must_use]
659    pub(crate) fn api_key_summaries(&self) -> Vec<ApiKeySummary> {
660        self.api_keys
661            .load()
662            .iter()
663            .map(|k| ApiKeySummary {
664                name: k.name.clone(),
665                role: k.role.clone(),
666                expires_at: k.expires_at.clone(),
667            })
668            .collect()
669    }
670
671    /// Log auth success: INFO on first occurrence per identity, DEBUG after.
672    fn log_auth(&self, id: &AuthIdentity, method: &str) {
673        self.counters.record_success(id.method);
674        let first = self
675            .seen_identities
676            .lock()
677            .unwrap_or_else(std::sync::PoisonError::into_inner)
678            .insert(id.name.clone());
679        if first {
680            tracing::info!(name = %id.name, role = %id.role, "{method} authenticated");
681        } else {
682            tracing::debug!(name = %id.name, role = %id.role, "{method} authenticated");
683        }
684    }
685}
686
687/// Default auth rate limit: 30 attempts per minute per source IP.
688// SAFETY: unwrap() is safe - literal 30 is provably non-zero (const-evaluated).
689const DEFAULT_AUTH_RATE: NonZeroU32 = NonZeroU32::new(30).unwrap();
690
691/// Create a post-failure rate limiter from config.
692#[must_use]
693pub(crate) fn build_rate_limiter(config: &RateLimitConfig) -> Arc<KeyedLimiter> {
694    let quota = governor::Quota::per_minute(
695        NonZeroU32::new(config.max_attempts_per_minute).unwrap_or(DEFAULT_AUTH_RATE),
696    );
697    Arc::new(BoundedKeyedLimiter::new(
698        quota,
699        config.max_tracked_keys,
700        config.idle_eviction,
701    ))
702}
703
704/// Create a pre-auth abuse-gate rate limiter from config.
705///
706/// Quota: `pre_auth_max_per_minute` if set, otherwise
707/// `max_attempts_per_minute * 10` (capped at `u32::MAX`). The 10× factor
708/// keeps the gate generous enough for honest retries while still bounding
709/// attacker CPU on Argon2 verification.
710#[must_use]
711pub(crate) fn build_pre_auth_limiter(config: &RateLimitConfig) -> Arc<KeyedLimiter> {
712    let resolved = config.pre_auth_max_per_minute.unwrap_or_else(|| {
713        config
714            .max_attempts_per_minute
715            .saturating_mul(PRE_AUTH_DEFAULT_MULTIPLIER)
716    });
717    let quota =
718        governor::Quota::per_minute(NonZeroU32::new(resolved).unwrap_or(DEFAULT_PRE_AUTH_RATE));
719    Arc::new(BoundedKeyedLimiter::new(
720        quota,
721        config.max_tracked_keys,
722        config.idle_eviction,
723    ))
724}
725
726/// Default multiplier applied to `max_attempts_per_minute` when the operator
727/// does not set `pre_auth_max_per_minute` explicitly.
728const PRE_AUTH_DEFAULT_MULTIPLIER: u32 = 10;
729
730/// Default pre-auth abuse-gate rate (used only if both the configured value
731/// and the multiplied fallback are zero, which `NonZeroU32::new` rejects).
732// SAFETY: unwrap() is safe - literal 300 is provably non-zero (const-evaluated).
733const DEFAULT_PRE_AUTH_RATE: NonZeroU32 = NonZeroU32::new(300).unwrap();
734
735/// Parse an mTLS client certificate and extract an `AuthIdentity`.
736///
737/// Reads the Subject CN as the identity name. Falls back to the first
738/// DNS SAN if CN is absent. The role is taken from the `MtlsConfig`.
739#[must_use]
740pub fn extract_mtls_identity(cert_der: &[u8], default_role: &str) -> Option<AuthIdentity> {
741    let (_, cert) = X509Certificate::from_der(cert_der).ok()?;
742
743    // Try CN from Subject first.
744    let cn = cert
745        .subject()
746        .iter_common_name()
747        .next()
748        .and_then(|attr| attr.as_str().ok())
749        .map(String::from);
750
751    // Fall back to first DNS SAN.
752    let name = cn.or_else(|| {
753        cert.subject_alternative_name()
754            .ok()
755            .flatten()
756            .and_then(|san| {
757                #[allow(clippy::wildcard_enum_match_arm)]
758                san.value.general_names.iter().find_map(|gn| match gn {
759                    GeneralName::DNSName(dns) => Some((*dns).to_owned()),
760                    _ => None,
761                })
762            })
763    })?;
764
765    // Reject identities with characters unsafe for logging and RBAC matching.
766    if !name
767        .chars()
768        .all(|c| c.is_alphanumeric() || matches!(c, '-' | '.' | '_' | '@'))
769    {
770        tracing::warn!(cn = %name, "mTLS identity rejected: invalid characters in CN/SAN");
771        return None;
772    }
773
774    Some(AuthIdentity {
775        name,
776        role: default_role.to_owned(),
777        method: AuthMethod::MtlsCertificate,
778        raw_token: None,
779        sub: None,
780    })
781}
782
783/// Verify a bearer token against configured API keys.
784///
785/// Argon2id verification is CPU-intensive, so this should be called via
786/// `spawn_blocking`. Returns the matching identity if the token is valid.
787///
788/// Iterates **all** keys to completion to prevent timing side-channels
789/// that would reveal how many keys exist or which slot matched.
790#[must_use]
791pub fn verify_bearer_token(token: &str, keys: &[ApiKeyEntry]) -> Option<AuthIdentity> {
792    let now = chrono::Utc::now();
793
794    // Always iterate ALL keys to completion to prevent timing side-channels
795    // that reveal how many keys exist or which position matched.
796    let mut result: Option<AuthIdentity> = None;
797
798    for key in keys {
799        // Check expiry
800        if let Some(ref expires) = key.expires_at
801            && let Ok(exp) = chrono::DateTime::parse_from_rfc3339(expires)
802            && exp < now
803        {
804            continue;
805        }
806
807        // Argon2id verification (constant-time internally).
808        // Keep the first match but continue checking remaining keys.
809        if result.is_none()
810            && let Ok(parsed_hash) = PasswordHash::new(&key.hash)
811            && Argon2::default()
812                .verify_password(token.as_bytes(), &parsed_hash)
813                .is_ok()
814        {
815            result = Some(AuthIdentity {
816                name: key.name.clone(),
817                role: key.role.clone(),
818                method: AuthMethod::BearerToken,
819                raw_token: None,
820                sub: None,
821            });
822        }
823    }
824    result
825}
826
827/// Generate a new API key: 256-bit random token + Argon2id hash.
828///
829/// Returns `(plaintext_token, argon2id_hash_phc_string)`.
830/// The plaintext is shown once to the user and never stored.
831///
832/// # Errors
833///
834/// Returns an error if salt encoding or Argon2id hashing fails
835/// (should not happen with valid inputs, but we avoid panicking).
836pub fn generate_api_key() -> Result<(String, String), McpxError> {
837    let mut token_bytes = [0u8; 32];
838    rand::fill(&mut token_bytes);
839    let token = URL_SAFE_NO_PAD.encode(token_bytes);
840
841    // Generate 16 random bytes for salt, encode as base64 for SaltString.
842    let mut salt_bytes = [0u8; 16];
843    rand::fill(&mut salt_bytes);
844    let salt = SaltString::encode_b64(&salt_bytes)
845        .map_err(|e| McpxError::Auth(format!("salt encoding failed: {e}")))?;
846    let hash = Argon2::default()
847        .hash_password(token.as_bytes(), &salt)
848        .map_err(|e| McpxError::Auth(format!("argon2id hashing failed: {e}")))?
849        .to_string();
850
851    Ok((token, hash))
852}
853
854fn build_www_authenticate_value(
855    advertise_resource_metadata: bool,
856    failure: AuthFailureClass,
857) -> String {
858    let (error, error_description) = failure.bearer_error();
859    if advertise_resource_metadata {
860        return format!(
861            "Bearer resource_metadata=\"/.well-known/oauth-protected-resource\", error=\"{error}\", error_description=\"{error_description}\""
862        );
863    }
864    format!("Bearer error=\"{error}\", error_description=\"{error_description}\"")
865}
866
867fn auth_method_label(method: AuthMethod) -> &'static str {
868    match method {
869        AuthMethod::MtlsCertificate => "mTLS",
870        AuthMethod::BearerToken => "bearer token",
871        AuthMethod::OAuthJwt => "OAuth JWT",
872    }
873}
874
875#[cfg_attr(not(feature = "oauth"), allow(unused_variables))]
876fn unauthorized_response(state: &AuthState, failure_class: AuthFailureClass) -> Response {
877    #[cfg(feature = "oauth")]
878    let advertise_resource_metadata = state.jwks_cache.is_some();
879    #[cfg(not(feature = "oauth"))]
880    let advertise_resource_metadata = false;
881
882    let challenge = build_www_authenticate_value(advertise_resource_metadata, failure_class);
883    (
884        axum::http::StatusCode::UNAUTHORIZED,
885        [(header::WWW_AUTHENTICATE, challenge)],
886        failure_class.response_body(),
887    )
888        .into_response()
889}
890
891async fn authenticate_bearer_identity(
892    state: &AuthState,
893    token: &str,
894) -> Result<AuthIdentity, AuthFailureClass> {
895    let mut failure_class = AuthFailureClass::MissingCredential;
896
897    #[cfg(feature = "oauth")]
898    if let Some(ref cache) = state.jwks_cache
899        && crate::oauth::looks_like_jwt(token)
900    {
901        match cache.validate_token_with_reason(token).await {
902            Ok(mut id) => {
903                id.raw_token = Some(SecretString::from(token.to_owned()));
904                return Ok(id);
905            }
906            Err(crate::oauth::JwtValidationFailure::Expired) => {
907                failure_class = AuthFailureClass::ExpiredCredential;
908            }
909            Err(crate::oauth::JwtValidationFailure::Invalid) => {
910                failure_class = AuthFailureClass::InvalidCredential;
911            }
912        }
913    }
914
915    let token = token.to_owned();
916    let keys = state.api_keys.load_full(); // Arc clone, lock-free
917
918    // Argon2id is CPU-bound - offload to blocking thread pool.
919    let identity = tokio::task::spawn_blocking(move || verify_bearer_token(&token, &keys))
920        .await
921        .ok()
922        .flatten();
923
924    if let Some(id) = identity {
925        return Ok(id);
926    }
927
928    if failure_class == AuthFailureClass::MissingCredential {
929        failure_class = AuthFailureClass::InvalidCredential;
930    }
931
932    Err(failure_class)
933}
934
935/// Consult the pre-auth abuse gate for the given peer.
936///
937/// Returns `Some(response)` if the request should be rejected (limiter
938/// configured AND quota exhausted for this source IP). Returns `None`
939/// otherwise (limiter absent, peer address unknown, or quota available),
940/// in which case the caller should proceed with credential verification.
941///
942/// Side effects on rejection: increments the `pre_auth_gate` failure
943/// counter and emits a warn-level log. mTLS-authenticated requests must
944/// be admitted by the caller *before* invoking this helper.
945fn pre_auth_gate(state: &AuthState, peer_addr: Option<SocketAddr>) -> Option<Response> {
946    let limiter = state.pre_auth_limiter.as_ref()?;
947    let addr = peer_addr?;
948    if limiter.check_key(&addr.ip()).is_ok() {
949        return None;
950    }
951    state.counters.record_failure(AuthFailureClass::PreAuthGate);
952    tracing::warn!(
953        ip = %addr.ip(),
954        "auth rate limited by pre-auth gate (request rejected before credential verification)"
955    );
956    Some(
957        McpxError::RateLimited("too many unauthenticated requests from this source".into())
958            .into_response(),
959    )
960}
961
962/// Axum middleware that enforces authentication.
963///
964/// Tries authentication methods in priority order:
965/// 1. mTLS client certificate identity (populated by TLS acceptor)
966/// 2. Bearer token from `Authorization` header
967///
968/// Failed authentication attempts are rate-limited per source IP.
969/// Successful authentications do not consume rate limit budget.
970pub(crate) async fn auth_middleware(
971    state: Arc<AuthState>,
972    req: Request<Body>,
973    next: Next,
974) -> Response {
975    // Extract peer address (and any mTLS identity) from ConnectInfo.
976    // Plain TCP: ConnectInfo<SocketAddr>. TLS / mTLS: ConnectInfo<TlsConnInfo>,
977    // which carries the verified identity directly on the connection — no
978    // shared map, no port-reuse aliasing.
979    let tls_info = req.extensions().get::<ConnectInfo<TlsConnInfo>>().cloned();
980    let peer_addr = req
981        .extensions()
982        .get::<ConnectInfo<SocketAddr>>()
983        .map(|ci| ci.0)
984        .or_else(|| tls_info.as_ref().map(|ci| ci.0.addr));
985
986    // 1. Try mTLS identity (extracted by the TLS acceptor during handshake
987    //    and attached to the connection itself).
988    //
989    //    mTLS connections bypass the pre-auth abuse gate below: the TLS
990    //    handshake already performed expensive crypto with a verified peer,
991    //    so we trust them not to be a CPU-spray attacker.
992    if let Some(id) = tls_info.and_then(|ci| ci.0.identity) {
993        state.log_auth(&id, "mTLS");
994        let mut req = req;
995        req.extensions_mut().insert(id);
996        return next.run(req).await;
997    }
998
999    // 2. Pre-auth abuse gate: rejects CPU-spray attacks BEFORE the Argon2id
1000    //    verification path runs. Keyed by source IP. mTLS connections (above)
1001    //    are exempt; this gate only protects the bearer/JWT verification path.
1002    if let Some(blocked) = pre_auth_gate(&state, peer_addr) {
1003        return blocked;
1004    }
1005
1006    let failure_class = if let Some(value) = req.headers().get(header::AUTHORIZATION) {
1007        match value.to_str().ok().and_then(|v| v.strip_prefix("Bearer ")) {
1008            Some(token) => match authenticate_bearer_identity(&state, token).await {
1009                Ok(id) => {
1010                    state.log_auth(&id, auth_method_label(id.method));
1011                    let mut req = req;
1012                    req.extensions_mut().insert(id);
1013                    return next.run(req).await;
1014                }
1015                Err(class) => class,
1016            },
1017            None => AuthFailureClass::InvalidCredential,
1018        }
1019    } else {
1020        AuthFailureClass::MissingCredential
1021    };
1022
1023    tracing::warn!(failure_class = %failure_class.as_str(), "auth failed");
1024
1025    // Rate limit check (applied after auth failure only).
1026    // Successful authentications do not consume rate limit budget.
1027    if let (Some(limiter), Some(addr)) = (&state.rate_limiter, peer_addr)
1028        && limiter.check_key(&addr.ip()).is_err()
1029    {
1030        state.counters.record_failure(AuthFailureClass::RateLimited);
1031        tracing::warn!(ip = %addr.ip(), "auth rate limited after repeated failures");
1032        return McpxError::RateLimited("too many failed authentication attempts".into())
1033            .into_response();
1034    }
1035
1036    state.counters.record_failure(failure_class);
1037    unauthorized_response(&state, failure_class)
1038}
1039
#[cfg(test)]
mod tests {
    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use tower::ServiceExt as _;

    use super::*;

    // -- shared helpers --

    /// Router with a single `POST /mcp` route guarded by `auth_middleware`.
    fn auth_router(state: Arc<AuthState>) -> axum::Router {
        axum::Router::new()
            .route("/mcp", axum::routing::post(|| async { "ok" }))
            .layer(axum::middleware::from_fn(move |req, next| {
                let s = Arc::clone(&state);
                auth_middleware(s, req, next)
            }))
    }

    /// Auth state holding the given API keys with no rate limiters configured.
    fn test_auth_state(keys: Vec<ApiKeyEntry>) -> Arc<AuthState> {
        Arc::new(AuthState {
            api_keys: ArcSwap::new(Arc::new(keys)),
            rate_limiter: None,
            pre_auth_limiter: None,
            #[cfg(feature = "oauth")]
            jwks_cache: None,
            seen_identities: Mutex::new(HashSet::new()),
            counters: AuthCounters::default(),
        })
    }

    /// Empty-bodied `POST /mcp` request, optionally carrying an
    /// `Authorization` header with the given value.
    fn mcp_request(authorization: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder()
            .method(axum::http::Method::POST)
            .uri("/mcp");
        if let Some(value) = authorization {
            builder = builder.header("authorization", value);
        }
        builder.body(Body::empty()).unwrap()
    }

    // -- API key generation and verification --

    #[test]
    fn generate_and_verify_api_key() {
        let (token, hash) = generate_api_key().unwrap();

        // Token is 43 chars (256-bit base64url, no padding)
        assert_eq!(token.len(), 43);

        // Hash is a valid PHC string
        assert!(hash.starts_with("$argon2id$"));

        // Verification succeeds with correct token
        let keys = vec![ApiKeyEntry {
            name: "test".into(),
            hash,
            role: "viewer".into(),
            expires_at: None,
        }];
        let id = verify_bearer_token(&token, &keys);
        assert!(id.is_some());
        let id = id.unwrap();
        assert_eq!(id.name, "test");
        assert_eq!(id.role, "viewer");
        assert_eq!(id.method, AuthMethod::BearerToken);
    }

    #[test]
    fn wrong_token_rejected() {
        let (_token, hash) = generate_api_key().unwrap();
        let keys = vec![ApiKeyEntry {
            name: "test".into(),
            hash,
            role: "viewer".into(),
            expires_at: None,
        }];
        assert!(verify_bearer_token("wrong-token", &keys).is_none());
    }

    #[test]
    fn expired_key_rejected() {
        let (token, hash) = generate_api_key().unwrap();
        let keys = vec![ApiKeyEntry {
            name: "test".into(),
            hash,
            role: "viewer".into(),
            expires_at: Some("2020-01-01T00:00:00Z".into()),
        }];
        assert!(verify_bearer_token(&token, &keys).is_none());
    }

    #[test]
    fn future_expiry_accepted() {
        let (token, hash) = generate_api_key().unwrap();
        let keys = vec![ApiKeyEntry {
            name: "test".into(),
            hash,
            role: "viewer".into(),
            expires_at: Some("2099-01-01T00:00:00Z".into()),
        }];
        assert!(verify_bearer_token(&token, &keys).is_some());
    }

    #[test]
    fn multiple_keys_first_match_wins() {
        let (token, hash) = generate_api_key().unwrap();
        let keys = vec![
            ApiKeyEntry {
                name: "wrong".into(),
                hash: "$argon2id$v=19$m=19456,t=2,p=1$invalid$invalid".into(),
                role: "ops".into(),
                expires_at: None,
            },
            ApiKeyEntry {
                name: "correct".into(),
                hash,
                role: "deploy".into(),
                expires_at: None,
            },
        ];
        let id = verify_bearer_token(&token, &keys).unwrap();
        assert_eq!(id.name, "correct");
        assert_eq!(id.role, "deploy");
    }

    // -- post-failure rate limiter --

    #[test]
    fn rate_limiter_allows_within_quota() {
        let config = RateLimitConfig {
            max_attempts_per_minute: 5,
            pre_auth_max_per_minute: None,
            ..Default::default()
        };
        let limiter = build_rate_limiter(&config);
        let ip: IpAddr = "10.0.0.1".parse().unwrap();

        // First 5 should succeed.
        for _ in 0..5 {
            assert!(limiter.check_key(&ip).is_ok());
        }
        // 6th should fail.
        assert!(limiter.check_key(&ip).is_err());
    }

    #[test]
    fn rate_limiter_separate_ips() {
        let config = RateLimitConfig {
            max_attempts_per_minute: 2,
            pre_auth_max_per_minute: None,
            ..Default::default()
        };
        let limiter = build_rate_limiter(&config);
        let ip1: IpAddr = "10.0.0.1".parse().unwrap();
        let ip2: IpAddr = "10.0.0.2".parse().unwrap();

        // Exhaust ip1's quota.
        assert!(limiter.check_key(&ip1).is_ok());
        assert!(limiter.check_key(&ip1).is_ok());
        assert!(limiter.check_key(&ip1).is_err());

        // ip2 should still have quota.
        assert!(limiter.check_key(&ip2).is_ok());
    }

    // -- mTLS identity extraction --

    #[test]
    fn extract_mtls_identity_from_cn() {
        // Generate a cert with explicit CN.
        let mut params = rcgen::CertificateParams::new(vec!["test-client.local".into()]).unwrap();
        params.distinguished_name = rcgen::DistinguishedName::new();
        params
            .distinguished_name
            .push(rcgen::DnType::CommonName, "test-client");
        let cert = params
            .self_signed(&rcgen::KeyPair::generate().unwrap())
            .unwrap();
        let der = cert.der();

        let id = extract_mtls_identity(der, "ops").unwrap();
        assert_eq!(id.name, "test-client");
        assert_eq!(id.role, "ops");
        assert_eq!(id.method, AuthMethod::MtlsCertificate);
    }

    #[test]
    fn extract_mtls_identity_falls_back_to_san() {
        // Cert with no CN but has a DNS SAN.
        let mut params =
            rcgen::CertificateParams::new(vec!["san-only.example.com".into()]).unwrap();
        params.distinguished_name = rcgen::DistinguishedName::new();
        // No CN set - should fall back to DNS SAN.
        let cert = params
            .self_signed(&rcgen::KeyPair::generate().unwrap())
            .unwrap();
        let der = cert.der();

        let id = extract_mtls_identity(der, "viewer").unwrap();
        assert_eq!(id.name, "san-only.example.com");
        assert_eq!(id.role, "viewer");
    }

    #[test]
    fn extract_mtls_identity_invalid_der() {
        assert!(extract_mtls_identity(b"not-a-cert", "viewer").is_none());
    }

    // -- auth_middleware integration tests --

    #[tokio::test]
    async fn middleware_rejects_no_credentials() {
        let state = test_auth_state(vec![]);
        let app = auth_router(Arc::clone(&state));
        let resp = app.oneshot(mcp_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        let challenge = resp
            .headers()
            .get(header::WWW_AUTHENTICATE)
            .unwrap()
            .to_str()
            .unwrap();
        assert!(challenge.contains("error=\"invalid_request\""));

        let counters = state.counters_snapshot();
        assert_eq!(counters.failure_missing_credential, 1);
    }

    #[tokio::test]
    async fn middleware_accepts_valid_bearer() {
        let (token, hash) = generate_api_key().unwrap();
        let keys = vec![ApiKeyEntry {
            name: "test-key".into(),
            hash,
            role: "ops".into(),
            expires_at: None,
        }];
        let state = test_auth_state(keys);
        let app = auth_router(Arc::clone(&state));
        let req = mcp_request(Some(&format!("Bearer {token}")));
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let counters = state.counters_snapshot();
        assert_eq!(counters.success_bearer, 1);
    }

    #[tokio::test]
    async fn middleware_rejects_wrong_bearer() {
        let (_token, hash) = generate_api_key().unwrap();
        let keys = vec![ApiKeyEntry {
            name: "test-key".into(),
            hash,
            role: "ops".into(),
            expires_at: None,
        }];
        let state = test_auth_state(keys);
        let app = auth_router(Arc::clone(&state));
        let req = mcp_request(Some("Bearer wrong-token-here"));
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        let challenge = resp
            .headers()
            .get(header::WWW_AUTHENTICATE)
            .unwrap()
            .to_str()
            .unwrap();
        assert!(challenge.contains("error=\"invalid_token\""));

        let counters = state.counters_snapshot();
        assert_eq!(counters.failure_invalid_credential, 1);
    }

    /// End-to-end: once the post-failure quota (1 per minute here) is used
    /// up, a further failed attempt from the same IP is answered with 429
    /// instead of 401. The peer address is supplied by inserting `ConnectInfo`
    /// into the request extensions, as a real server would.
    #[tokio::test]
    async fn middleware_rate_limits() {
        let state = Arc::new(AuthState {
            api_keys: ArcSwap::new(Arc::new(vec![])),
            rate_limiter: Some(build_rate_limiter(&RateLimitConfig {
                max_attempts_per_minute: 1,
                pre_auth_max_per_minute: None,
                ..Default::default()
            })),
            pre_auth_limiter: None,
            #[cfg(feature = "oauth")]
            jwks_cache: None,
            seen_identities: Mutex::new(HashSet::new()),
            counters: AuthCounters::default(),
        });
        let app = auth_router(Arc::clone(&state));
        let peer: SocketAddr = "10.0.0.30:54321".parse().unwrap();

        // First failed attempt: consumes the only unit of quota, gets 401.
        let mut req1 = mcp_request(None);
        req1.extensions_mut().insert(ConnectInfo(peer));
        let resp1 = app.clone().oneshot(req1).await.unwrap();
        assert_eq!(
            resp1.status(),
            StatusCode::UNAUTHORIZED,
            "first failure is within quota and must be a plain 401"
        );

        // Second failed attempt from the same IP: quota exhausted, gets 429.
        let mut req2 = mcp_request(None);
        req2.extensions_mut().insert(ConnectInfo(peer));
        let resp2 = app.oneshot(req2).await.unwrap();
        assert_eq!(
            resp2.status(),
            StatusCode::TOO_MANY_REQUESTS,
            "second failure from the same IP must be rate limited"
        );

        // The rate-limited request is not also counted as a missing credential.
        let counters = state.counters_snapshot();
        assert_eq!(
            counters.failure_missing_credential, 1,
            "only the un-throttled first request counts as a missing credential"
        );
    }

    /// Verify the rate limit semantics: only failed auth attempts consume budget.
    ///
    /// This is a unit test of the limiter behavior. The middleware integration
    /// is that on auth failure, `check_key` is called; on auth success, it is NOT.
    #[test]
    fn rate_limit_semantics_failed_only() {
        let config = RateLimitConfig {
            max_attempts_per_minute: 3,
            pre_auth_max_per_minute: None,
            ..Default::default()
        };
        let limiter = build_rate_limiter(&config);
        let ip: IpAddr = "192.168.1.100".parse().unwrap();

        // Simulate: 3 failed attempts should exhaust quota.
        assert!(
            limiter.check_key(&ip).is_ok(),
            "failure 1 should be allowed"
        );
        assert!(
            limiter.check_key(&ip).is_ok(),
            "failure 2 should be allowed"
        );
        assert!(
            limiter.check_key(&ip).is_ok(),
            "failure 3 should be allowed"
        );
        assert!(
            limiter.check_key(&ip).is_err(),
            "failure 4 should be blocked"
        );

        // In the actual middleware flow:
        // - Successful auth: verify_bearer_token returns Some, we return early
        //   WITHOUT calling check_key, so no budget consumed.
        // - Failed auth: verify_bearer_token returns None, we call check_key
        //   THEN return 401, so budget is consumed.
        //
        // This means N successful requests followed by M failed requests
        // will only count M toward the rate limit, not N+M.
    }

    // -- pre-auth abuse gate (H-S1) --

    /// The pre-auth gate must default to ~10x the post-failure quota so honest
    /// retry storms never trip it but an Argon2-spray attacker is throttled.
    #[test]
    fn pre_auth_default_multiplier_is_10x() {
        let config = RateLimitConfig {
            max_attempts_per_minute: 5,
            pre_auth_max_per_minute: None,
            ..Default::default()
        };
        let limiter = build_pre_auth_limiter(&config);
        let ip: IpAddr = "10.0.0.1".parse().unwrap();

        // Quota should be 50 (5 * 10), not 5. We expect the first 50 to pass.
        for i in 0..50 {
            assert!(
                limiter.check_key(&ip).is_ok(),
                "pre-auth attempt {i} (of expected 50) should be allowed under default 10x multiplier"
            );
        }
        // The 51st attempt must be blocked: confirms quota is bounded, not infinite.
        assert!(
            limiter.check_key(&ip).is_err(),
            "pre-auth attempt 51 should be blocked (quota is 50, not unbounded)"
        );
    }

    /// An explicit `pre_auth_max_per_minute` override must win over the
    /// 10x-multiplier default.
    #[test]
    fn pre_auth_explicit_override_wins() {
        let config = RateLimitConfig {
            max_attempts_per_minute: 100,     // would default to 1000 pre-auth quota
            pre_auth_max_per_minute: Some(2), // but operator caps at 2
            ..Default::default()
        };
        let limiter = build_pre_auth_limiter(&config);
        let ip: IpAddr = "10.0.0.2".parse().unwrap();

        assert!(limiter.check_key(&ip).is_ok(), "attempt 1 allowed");
        assert!(limiter.check_key(&ip).is_ok(), "attempt 2 allowed");
        assert!(
            limiter.check_key(&ip).is_err(),
            "attempt 3 must be blocked (explicit override of 2 wins over 10x default of 1000)"
        );
    }

    /// End-to-end: the pre-auth gate must reject before the bearer-verification
    /// path runs. We exhaust the gate's quota (Some(1)) with one bad-bearer
    /// request, then the second request must be rejected with 429 + the
    /// `pre_auth_gate` failure counter incremented (NOT
    /// `failure_invalid_credential`, which would prove Argon2 ran).
    #[tokio::test]
    async fn pre_auth_gate_blocks_before_argon2_verification() {
        let (_token, hash) = generate_api_key().unwrap();
        let keys = vec![ApiKeyEntry {
            name: "test-key".into(),
            hash,
            role: "ops".into(),
            expires_at: None,
        }];
        let config = RateLimitConfig {
            max_attempts_per_minute: 100,
            pre_auth_max_per_minute: Some(1),
            ..Default::default()
        };
        let state = Arc::new(AuthState {
            api_keys: ArcSwap::new(Arc::new(keys)),
            rate_limiter: None,
            pre_auth_limiter: Some(build_pre_auth_limiter(&config)),
            #[cfg(feature = "oauth")]
            jwks_cache: None,
            seen_identities: Mutex::new(HashSet::new()),
            counters: AuthCounters::default(),
        });
        let app = auth_router(Arc::clone(&state));
        let peer: SocketAddr = "10.0.0.10:54321".parse().unwrap();

        // First bad-bearer request: gate has quota, bearer verification runs,
        // returns 401 (invalid credential).
        let mut req1 = mcp_request(Some("Bearer obviously-not-a-real-token"));
        req1.extensions_mut().insert(ConnectInfo(peer));
        let resp1 = app.clone().oneshot(req1).await.unwrap();
        assert_eq!(
            resp1.status(),
            StatusCode::UNAUTHORIZED,
            "first attempt: gate has quota, falls through to bearer auth which fails with 401"
        );

        // Second bad-bearer request from same IP: gate quota exhausted, must
        // reject with 429 BEFORE the Argon2 verification path runs.
        let mut req2 = mcp_request(Some("Bearer also-not-a-real-token"));
        req2.extensions_mut().insert(ConnectInfo(peer));
        let resp2 = app.oneshot(req2).await.unwrap();
        assert_eq!(
            resp2.status(),
            StatusCode::TOO_MANY_REQUESTS,
            "second attempt from same IP: pre-auth gate must reject with 429"
        );

        let counters = state.counters_snapshot();
        assert_eq!(
            counters.failure_pre_auth_gate, 1,
            "exactly one request must have been rejected by the pre-auth gate"
        );
        // Critical: Argon2 verification must NOT have run on the gated request.
        // The first request's 401 increments `failure_invalid_credential` to 1;
        // the second (gated) request must NOT increment it further.
        assert_eq!(
            counters.failure_invalid_credential, 1,
            "bearer verification must run exactly once (only the un-gated first request)"
        );
    }

    /// mTLS-authenticated requests must bypass the pre-auth gate entirely.
    /// The TLS handshake already performed expensive crypto with a verified
    /// peer, so mTLS callers should never be throttled by this gate.
    ///
    /// Setup: a pre-auth gate with quota 1 (very tight). Submit three mTLS
    /// requests in quick succession from the same IP. All must succeed.
    #[tokio::test]
    async fn pre_auth_gate_does_not_throttle_mtls() {
        let config = RateLimitConfig {
            max_attempts_per_minute: 100,
            pre_auth_max_per_minute: Some(1), // tight: would block 2nd plain request
            ..Default::default()
        };
        let state = Arc::new(AuthState {
            api_keys: ArcSwap::new(Arc::new(vec![])),
            rate_limiter: None,
            pre_auth_limiter: Some(build_pre_auth_limiter(&config)),
            #[cfg(feature = "oauth")]
            jwks_cache: None,
            seen_identities: Mutex::new(HashSet::new()),
            counters: AuthCounters::default(),
        });
        let app = auth_router(Arc::clone(&state));
        let peer: SocketAddr = "10.0.0.20:54321".parse().unwrap();
        let identity = AuthIdentity {
            name: "cn=test-client".into(),
            role: "viewer".into(),
            method: AuthMethod::MtlsCertificate,
            raw_token: None,
            sub: None,
        };
        let tls_info = TlsConnInfo::new(peer, Some(identity));

        for i in 0..3 {
            let mut req = mcp_request(None);
            req.extensions_mut().insert(ConnectInfo(tls_info.clone()));
            let resp = app.clone().oneshot(req).await.unwrap();
            assert_eq!(
                resp.status(),
                StatusCode::OK,
                "mTLS request {i} must succeed: pre-auth gate must not apply to mTLS callers"
            );
        }

        let counters = state.counters_snapshot();
        assert_eq!(
            counters.failure_pre_auth_gate, 0,
            "pre-auth gate counter must remain at zero: mTLS bypasses the gate"
        );
        assert_eq!(
            counters.success_mtls, 3,
            "all three mTLS requests must have been counted as successful"
        );
    }
}
1571}