Skip to main content

rmcp_server_kit/
auth.rs

1//! Authentication middleware for MCP servers.
2//!
3//! Supports multiple authentication methods tried in priority order:
4//! 1. mTLS client certificate (if configured and peer cert present)
5//! 2. Bearer token (API key) with Argon2id hash verification
6//!
7//! Includes per-source-IP rate limiting on authentication attempts.
8
9use std::{
10    collections::HashSet,
11    net::{IpAddr, SocketAddr},
12    num::NonZeroU32,
13    path::PathBuf,
14    sync::{
15        Arc, Mutex,
16        atomic::{AtomicU64, Ordering},
17    },
18    time::Duration,
19};
20
21use arc_swap::ArcSwap;
22use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier, password_hash::SaltString};
23use axum::{
24    body::Body,
25    extract::ConnectInfo,
26    http::{Request, header},
27    middleware::Next,
28    response::{IntoResponse, Response},
29};
30use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
31use secrecy::SecretString;
32use serde::Deserialize;
33use x509_parser::prelude::*;
34
35use crate::{bounded_limiter::BoundedKeyedLimiter, error::McpxError};
36
/// Identity of an authenticated caller.
///
/// The [`Debug`] impl is **manually written** to redact the raw bearer token
/// and the JWT `sub` claim. This prevents accidental disclosure if an
/// `AuthIdentity` is ever logged via `tracing::debug!(?identity, …)` or
/// `format!("{identity:?}")`. Only `name`, `role`, and `method` are printed
/// in the clear; `raw_token` and `sub` are each rendered as `<redacted>`
/// when present and `<none>` when absent.
#[derive(Clone)]
#[non_exhaustive]
pub struct AuthIdentity {
    /// Human-readable identity name (e.g. API key label or cert CN).
    pub name: String,
    /// RBAC role associated with this identity.
    pub role: String,
    /// Which authentication mechanism produced this identity.
    pub method: AuthMethod,
    /// Raw bearer token from the `Authorization` header, wrapped in
    /// [`SecretString`] so it is never accidentally logged or serialized.
    /// Present for OAuth JWT; `None` for mTLS and API-key auth.
    /// Tool handlers use this for downstream token passthrough via
    /// [`crate::rbac::current_token`].
    pub raw_token: Option<SecretString>,
    /// JWT `sub` claim (stable user identifier, e.g. Keycloak UUID).
    /// Used for token store keying. `None` for non-JWT auth.
    ///
    /// Stored as a plain `String`, so it is only protected by the manual
    /// [`Debug`] impl — avoid formatting this field directly in logs.
    pub sub: Option<String>,
}
64
65impl std::fmt::Debug for AuthIdentity {
66    /// Redacts `raw_token` and `sub` to prevent secret leakage via
67    /// `format!("{:?}")` or `tracing::debug!(?identity)`.
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        f.debug_struct("AuthIdentity")
70            .field("name", &self.name)
71            .field("role", &self.role)
72            .field("method", &self.method)
73            .field(
74                "raw_token",
75                &if self.raw_token.is_some() {
76                    "<redacted>"
77                } else {
78                    "<none>"
79                },
80            )
81            .field(
82                "sub",
83                &if self.sub.is_some() {
84                    "<redacted>"
85                } else {
86                    "<none>"
87                },
88            )
89            .finish()
90    }
91}
92
/// How the caller authenticated.
///
/// Stored on every [`AuthIdentity`] in its `method` field. The enum is
/// `#[non_exhaustive]`, so `match` expressions outside this crate must
/// include a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum AuthMethod {
    /// Bearer API key (Argon2id-hashed, configured statically).
    BearerToken,
    /// Mutual TLS client certificate.
    MtlsCertificate,
    /// OAuth 2.1 JWT bearer token (validated via JWKS).
    OAuthJwt,
}
104
/// Internal classification of why an authentication attempt was rejected.
///
/// Each class maps to a short label ([`Self::as_str`]), an error
/// code/description pair ([`Self::bearer_error`]) and a plain-text message
/// ([`Self::response_body`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AuthFailureClass {
    /// No credential was presented at all.
    MissingCredential,
    /// A credential was presented but is malformed or wrong.
    InvalidCredential,
    /// A credential was presented but has expired.
    #[cfg_attr(not(feature = "oauth"), allow(dead_code))]
    ExpiredCredential,
    /// Source IP exceeded the post-failure backoff limit.
    RateLimited,
    /// Source IP exceeded the pre-auth abuse gate (rejected before any
    /// password-hash work — see [`AuthState::pre_auth_limiter`]).
    PreAuthGate,
}

impl AuthFailureClass {
    /// Short snake_case label identifying this failure class.
    fn as_str(self) -> &'static str {
        match self {
            Self::PreAuthGate => "pre_auth_gate",
            Self::RateLimited => "rate_limited",
            Self::ExpiredCredential => "expired_credential",
            Self::InvalidCredential => "invalid_credential",
            Self::MissingCredential => "missing_credential",
        }
    }

    /// Returns an `(error code, human-readable description)` pair.
    ///
    /// NOTE(review): the codes are the RFC 6750 `error` values
    /// (`invalid_request`, `invalid_token`); presumably they feed a
    /// `WWW-Authenticate: Bearer` challenge — confirm at the call site.
    fn bearer_error(self) -> (&'static str, &'static str) {
        // Only credential problems are token errors; everything else is a
        // problem with the request itself.
        let code = match self {
            Self::InvalidCredential | Self::ExpiredCredential => "invalid_token",
            Self::MissingCredential | Self::RateLimited | Self::PreAuthGate => "invalid_request",
        };
        let description = match self {
            Self::MissingCredential => "missing bearer token or mTLS client certificate",
            Self::InvalidCredential => "token is invalid",
            Self::ExpiredCredential => "token is expired",
            Self::RateLimited => "too many failed authentication attempts",
            Self::PreAuthGate => "too many unauthenticated requests from this source",
        };
        (code, description)
    }

    /// Plain-text message for this failure class (by its name, presumably
    /// the HTTP response body; the call site is outside this block).
    fn response_body(self) -> &'static str {
        match self {
            Self::PreAuthGate => "rate limited (pre-auth)",
            Self::RateLimited => "rate limited",
            Self::ExpiredCredential => "unauthorized: expired credential",
            Self::InvalidCredential => "unauthorized: invalid credential",
            Self::MissingCredential => "unauthorized: missing credential",
        }
    }
}
155
/// Snapshot of authentication success/failure counters.
///
/// Plain-data copy of the internal atomic counters, serializable for
/// diagnostics. Each field is loaded independently, so a snapshot taken
/// while requests are in flight is not a single atomic cut across counters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
#[non_exhaustive]
pub struct AuthCountersSnapshot {
    /// Successful mTLS authentications.
    pub success_mtls: u64,
    /// Successful bearer-token authentications.
    pub success_bearer: u64,
    /// Successful OAuth JWT authentications.
    pub success_oauth_jwt: u64,
    /// Failures because no credential was presented.
    pub failure_missing_credential: u64,
    /// Failures because the credential was malformed or wrong.
    pub failure_invalid_credential: u64,
    /// Failures because the credential had expired.
    pub failure_expired_credential: u64,
    /// Failures because the source IP was rate-limited (post-failure backoff).
    pub failure_rate_limited: u64,
    /// Failures because the source IP exceeded the pre-auth abuse gate.
    /// These never reach the password-hash verification path.
    pub failure_pre_auth_gate: u64,
}
178
179/// Internal atomic counters backing [`AuthCountersSnapshot`].
180#[derive(Debug, Default)]
181pub(crate) struct AuthCounters {
182    success_mtls: AtomicU64,
183    success_bearer: AtomicU64,
184    success_oauth_jwt: AtomicU64,
185    failure_missing_credential: AtomicU64,
186    failure_invalid_credential: AtomicU64,
187    failure_expired_credential: AtomicU64,
188    failure_rate_limited: AtomicU64,
189    failure_pre_auth_gate: AtomicU64,
190}
191
192impl AuthCounters {
193    fn record_success(&self, method: AuthMethod) {
194        match method {
195            AuthMethod::MtlsCertificate => {
196                self.success_mtls.fetch_add(1, Ordering::Relaxed);
197            }
198            AuthMethod::BearerToken => {
199                self.success_bearer.fetch_add(1, Ordering::Relaxed);
200            }
201            AuthMethod::OAuthJwt => {
202                self.success_oauth_jwt.fetch_add(1, Ordering::Relaxed);
203            }
204        }
205    }
206
207    fn record_failure(&self, class: AuthFailureClass) {
208        match class {
209            AuthFailureClass::MissingCredential => {
210                self.failure_missing_credential
211                    .fetch_add(1, Ordering::Relaxed);
212            }
213            AuthFailureClass::InvalidCredential => {
214                self.failure_invalid_credential
215                    .fetch_add(1, Ordering::Relaxed);
216            }
217            AuthFailureClass::ExpiredCredential => {
218                self.failure_expired_credential
219                    .fetch_add(1, Ordering::Relaxed);
220            }
221            AuthFailureClass::RateLimited => {
222                self.failure_rate_limited.fetch_add(1, Ordering::Relaxed);
223            }
224            AuthFailureClass::PreAuthGate => {
225                self.failure_pre_auth_gate.fetch_add(1, Ordering::Relaxed);
226            }
227        }
228    }
229
230    fn snapshot(&self) -> AuthCountersSnapshot {
231        AuthCountersSnapshot {
232            success_mtls: self.success_mtls.load(Ordering::Relaxed),
233            success_bearer: self.success_bearer.load(Ordering::Relaxed),
234            success_oauth_jwt: self.success_oauth_jwt.load(Ordering::Relaxed),
235            failure_missing_credential: self.failure_missing_credential.load(Ordering::Relaxed),
236            failure_invalid_credential: self.failure_invalid_credential.load(Ordering::Relaxed),
237            failure_expired_credential: self.failure_expired_credential.load(Ordering::Relaxed),
238            failure_rate_limited: self.failure_rate_limited.load(Ordering::Relaxed),
239            failure_pre_auth_gate: self.failure_pre_auth_gate.load(Ordering::Relaxed),
240        }
241    }
242}
243
/// A single API key entry (stored as Argon2id hash in config).
///
/// The [`Debug`] impl is **manually written** to redact the Argon2id hash.
/// Although the hash is not directly reversible, treating it as a secret
/// prevents offline brute-force attempts from leaked logs and matches the
/// defense-in-depth posture used for [`AuthIdentity`].
#[derive(Clone, Deserialize)]
#[non_exhaustive]
pub struct ApiKeyEntry {
    /// Human-readable key label (used in logs and audit records).
    pub name: String,
    /// Argon2id hash of the token (PHC string format).
    pub hash: String,
    /// RBAC role granted when this key authenticates successfully.
    pub role: String,
    /// Optional expiry in RFC 3339 format. `None` means no expiry is set
    /// (the state produced by [`ApiKeyEntry::new`]).
    pub expires_at: Option<String>,
}
262
263impl std::fmt::Debug for ApiKeyEntry {
264    /// Redacts the Argon2id `hash` to keep it out of logs, panic backtraces,
265    /// and admin-endpoint responses that might `format!("{:?}", …)` an entry.
266    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
267        f.debug_struct("ApiKeyEntry")
268            .field("name", &self.name)
269            .field("hash", &"<redacted>")
270            .field("role", &self.role)
271            .field("expires_at", &self.expires_at)
272            .finish()
273    }
274}
275
276impl ApiKeyEntry {
277    /// Create a new API key entry (no expiry).
278    #[must_use]
279    pub fn new(name: impl Into<String>, hash: impl Into<String>, role: impl Into<String>) -> Self {
280        Self {
281            name: name.into(),
282            hash: hash.into(),
283            role: role.into(),
284            expires_at: None,
285        }
286    }
287
288    /// Set an RFC 3339 expiry on this key.
289    #[must_use]
290    pub fn with_expiry(mut self, expires_at: impl Into<String>) -> Self {
291        self.expires_at = Some(expires_at.into());
292        self
293    }
294}
295
/// mTLS client certificate authentication configuration.
///
/// Only `ca_cert_path` is mandatory when deserializing; every other field
/// carries a serde default (stated on the field).
#[derive(Debug, Clone, Deserialize)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "mTLS CRL behavior is intentionally configured as independent booleans"
)]
#[non_exhaustive]
pub struct MtlsConfig {
    /// Path to CA certificate(s) for verifying client certs (PEM format).
    pub ca_cert_path: PathBuf,
    /// If true, clients MUST present a valid certificate.
    /// If false, client certs are optional (verified if presented).
    /// Default: `false`.
    #[serde(default)]
    pub required: bool,
    /// Default RBAC role for mTLS-authenticated clients.
    /// The client cert CN becomes the identity name.
    /// Default: `"viewer"`.
    #[serde(default = "default_mtls_role")]
    pub default_role: String,
    /// Enable CRL-based certificate revocation checks using CDP URLs from the
    /// configured CA chain and connecting client certificates.
    /// Default: `true`.
    #[serde(default = "default_true")]
    pub crl_enabled: bool,
    /// Optional fixed refresh interval for known CRLs. When omitted, refresh
    /// cadence is derived from `nextUpdate` and clamped internally.
    #[serde(default, with = "humantime_serde::option")]
    pub crl_refresh_interval: Option<Duration>,
    /// Timeout for individual CRL fetches. Default: 30 seconds.
    #[serde(default = "default_crl_fetch_timeout", with = "humantime_serde")]
    pub crl_fetch_timeout: Duration,
    /// Grace window during which stale CRLs may still be used when refresh
    /// attempts fail. Default: 24 hours.
    #[serde(default = "default_crl_stale_grace", with = "humantime_serde")]
    pub crl_stale_grace: Duration,
    /// When true, missing or unavailable CRLs cause revocation checks to fail
    /// closed. Default: `false`.
    #[serde(default)]
    pub crl_deny_on_unavailable: bool,
    /// When true, apply revocation checks only to the end-entity certificate.
    /// Default: `false`.
    #[serde(default)]
    pub crl_end_entity_only: bool,
    /// Allow HTTP CRL distribution-point URLs in addition to HTTPS.
    ///
    /// Defaults to `true` because RFC 5280 §4.2.1.13 designates HTTP (and
    /// LDAP) as the canonical transport for CRL distribution points.
    /// SSRF defense for HTTP CDPs is provided by the IP-allowlist guard
    /// (private/loopback/link-local/multicast/cloud-metadata addresses are
    /// always rejected), redirect=none, body-size cap, and per-host
    /// concurrency limit -- not by forcing HTTPS.
    #[serde(default = "default_true")]
    pub crl_allow_http: bool,
    /// Enforce CRL expiration during certificate validation.
    /// Default: `true`.
    #[serde(default = "default_true")]
    pub crl_enforce_expiration: bool,
    /// Maximum concurrent CRL fetches across all hosts. Defense in depth
    /// against SSRF amplification: even if many CDPs are discovered, no
    /// more than this many fetches run in parallel. Per-host concurrency
    /// is independently capped at 1 regardless of this value.
    /// Default: `4`.
    #[serde(default = "default_crl_max_concurrent_fetches")]
    pub crl_max_concurrent_fetches: usize,
    /// Hard cap on each CRL response body in bytes. Fetches exceeding this
    /// are aborted mid-stream to bound memory and prevent gzip-bomb-style
    /// amplification. Default: 5 MiB (`5 * 1024 * 1024`).
    #[serde(default = "default_crl_max_response_bytes")]
    pub crl_max_response_bytes: u64,
    /// Global CDP discovery rate limit, in URLs per minute. Throttles
    /// how many *new* CDP URLs the verifier may admit into the fetch
    /// pipeline across the whole process, bounding asymmetric `DoS`
    /// amplification when attacker-controlled certificates carry large
    /// CDP lists. The limit is global (not per-source-IP) in this
    /// release; per-IP scoping is deferred to a future version because
    /// it requires plumbing the peer `SocketAddr` through the verifier
    /// hook. URLs that lose the rate-limiter race are *not* marked as
    /// seen, so subsequent handshakes observing the same URL can
    /// retry admission.
    /// Default: `60`.
    #[serde(default = "default_crl_discovery_rate_per_min")]
    pub crl_discovery_rate_per_min: u32,
    /// Maximum number of distinct hosts that may hold a CRL fetch
    /// semaphore at any time. Requests that would grow the map beyond
    /// this cap return [`McpxError::Config`] containing the literal
    /// substring `"crl_host_semaphore_cap_exceeded"`. Bounds memory
    /// growth from attacker-controlled CDP URLs pointing at unique
    /// hostnames. Default: 1024.
    #[serde(default = "default_crl_max_host_semaphores")]
    pub crl_max_host_semaphores: usize,
    /// Maximum number of distinct URLs tracked in the "seen" set.
    /// Beyond this, additional discovered URLs are silently dropped
    /// with a rate-limited warn! log; no error surfaces. Default: 4096.
    #[serde(default = "default_crl_max_seen_urls")]
    pub crl_max_seen_urls: usize,
    /// Maximum number of cached CRL entries. Beyond this, new
    /// successful fetches are silently dropped with a rate-limited
    /// warn! log (newest-rejected, not LRU-evicted). Default: 1024.
    #[serde(default = "default_crl_max_cache_entries")]
    pub crl_max_cache_entries: usize,
}
393
/// Serde default for [`MtlsConfig::default_role`]: the `"viewer"` role.
fn default_mtls_role() -> String {
    String::from("viewer")
}

/// Serde default for boolean options that are enabled unless the operator
/// turns them off.
const fn default_true() -> bool {
    true
}

/// Serde default for [`MtlsConfig::crl_fetch_timeout`]: 30 seconds.
const fn default_crl_fetch_timeout() -> Duration {
    Duration::new(30, 0)
}

/// Serde default for [`MtlsConfig::crl_stale_grace`]: 24 hours.
const fn default_crl_stale_grace() -> Duration {
    const SECS_PER_HOUR: u64 = 60 * 60;
    Duration::from_secs(24 * SECS_PER_HOUR)
}

/// Serde default for [`MtlsConfig::crl_max_concurrent_fetches`].
const fn default_crl_max_concurrent_fetches() -> usize {
    4
}

/// Serde default for [`MtlsConfig::crl_max_response_bytes`]: 5 MiB.
const fn default_crl_max_response_bytes() -> u64 {
    const MIB: u64 = 1024 * 1024;
    5 * MIB
}

/// Serde default for [`MtlsConfig::crl_discovery_rate_per_min`].
const fn default_crl_discovery_rate_per_min() -> u32 {
    60
}

/// Serde default for [`MtlsConfig::crl_max_host_semaphores`].
const fn default_crl_max_host_semaphores() -> usize {
    1 << 10
}

/// Serde default for [`MtlsConfig::crl_max_seen_urls`].
const fn default_crl_max_seen_urls() -> usize {
    1 << 12
}

/// Serde default for [`MtlsConfig::crl_max_cache_entries`].
const fn default_crl_max_cache_entries() -> usize {
    1 << 10
}
433
/// Rate limiting configuration for authentication attempts.
///
/// rmcp-server-kit uses two independent per-IP token-bucket limiters for auth:
///
/// 1. **Pre-auth abuse gate** ([`Self::pre_auth_max_per_minute`]): consulted
///    *before* any password-hash work. Throttles unauthenticated traffic from
///    a single source IP so an attacker cannot pin the CPU on Argon2id by
///    spraying invalid bearer tokens. Sized generously (default = 10× the
///    post-failure quota) so legitimate clients are unaffected. mTLS-
///    authenticated connections bypass this gate entirely (the TLS handshake
///    already performed expensive crypto with a verified peer).
/// 2. **Post-failure backoff** ([`Self::max_attempts_per_minute`]): consulted
///    *after* an authentication attempt fails. Provides explicit backpressure
///    on bad credentials.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct RateLimitConfig {
    /// Maximum failed authentication attempts per source IP per minute.
    /// Successful authentications do not consume this budget.
    /// Default: `30`. A value of `0` cannot form a quota; the limiter
    /// builder substitutes the default of 30 in that case.
    #[serde(default = "default_max_attempts")]
    pub max_attempts_per_minute: u32,
    /// Maximum *unauthenticated* requests per source IP per minute admitted
    /// to the password-hash verification path. When `None`, defaults to
    /// `max_attempts_per_minute * 10` at limiter-construction time.
    ///
    /// Set higher than [`Self::max_attempts_per_minute`] so honest clients
    /// retrying with the wrong key never trip this gate; its purpose is only
    /// to bound CPU usage under spray attacks.
    #[serde(default)]
    pub pre_auth_max_per_minute: Option<u32>,
    /// Hard cap on the number of distinct source IPs tracked per limiter.
    /// When reached, idle entries are pruned first; if still full, the
    /// oldest (LRU) entry is evicted to make room for the new one. This
    /// bounds memory under IP-spray attacks. Default: `10_000`.
    #[serde(default = "default_max_tracked_keys")]
    pub max_tracked_keys: usize,
    /// Per-IP entries idle for longer than this are eligible for
    /// opportunistic pruning. Default: 15 minutes.
    #[serde(default = "default_idle_eviction", with = "humantime_serde")]
    pub idle_eviction: Duration,
}
475
476impl Default for RateLimitConfig {
477    fn default() -> Self {
478        Self {
479            max_attempts_per_minute: default_max_attempts(),
480            pre_auth_max_per_minute: None,
481            max_tracked_keys: default_max_tracked_keys(),
482            idle_eviction: default_idle_eviction(),
483        }
484    }
485}
486
487impl RateLimitConfig {
488    /// Create a rate limit config with the given max failed attempts per minute.
489    /// Pre-auth gate defaults to `10x` this value at limiter-construction time.
490    /// Memory-bound defaults are `10_000` tracked keys with 15-minute idle eviction.
491    #[must_use]
492    pub fn new(max_attempts_per_minute: u32) -> Self {
493        Self {
494            max_attempts_per_minute,
495            ..Self::default()
496        }
497    }
498
499    /// Override the pre-auth abuse-gate quota (per source IP per minute).
500    /// When unset, defaults to `max_attempts_per_minute * 10`.
501    #[must_use]
502    pub fn with_pre_auth_max_per_minute(mut self, quota: u32) -> Self {
503        self.pre_auth_max_per_minute = Some(quota);
504        self
505    }
506
507    /// Override the per-limiter cap on tracked source-IP keys (default `10_000`).
508    #[must_use]
509    pub fn with_max_tracked_keys(mut self, max: usize) -> Self {
510        self.max_tracked_keys = max;
511        self
512    }
513
514    /// Override the idle-eviction window (default 15 minutes).
515    #[must_use]
516    pub fn with_idle_eviction(mut self, idle: Duration) -> Self {
517        self.idle_eviction = idle;
518        self
519    }
520}
521
/// Serde default for [`RateLimitConfig::max_attempts_per_minute`].
fn default_max_attempts() -> u32 {
    30
}

/// Serde default for [`RateLimitConfig::max_tracked_keys`].
fn default_max_tracked_keys() -> usize {
    10_000
}

/// Serde default for [`RateLimitConfig::idle_eviction`]: 15 minutes.
fn default_idle_eviction() -> Duration {
    const SECS_PER_MIN: u64 = 60;
    Duration::from_secs(15 * SECS_PER_MIN)
}
533
/// Authentication configuration.
///
/// Every field has a default when deserializing: `enabled` is `false`,
/// `api_keys` is empty, and the optional sections are `None` when omitted.
#[derive(Debug, Clone, Default, Deserialize)]
#[non_exhaustive]
pub struct AuthConfig {
    /// Master switch - when false, all requests are allowed through.
    #[serde(default)]
    pub enabled: bool,
    /// Bearer token API keys.
    #[serde(default)]
    pub api_keys: Vec<ApiKeyEntry>,
    /// mTLS client certificate authentication.
    /// `None` when mTLS is not configured.
    pub mtls: Option<MtlsConfig>,
    /// Rate limiting for auth attempts.
    /// `None` when no rate-limit section is configured.
    pub rate_limit: Option<RateLimitConfig>,
    /// OAuth 2.1 JWT bearer token authentication.
    /// Only present when the crate is built with the `oauth` feature.
    #[cfg(feature = "oauth")]
    pub oauth: Option<crate::oauth::OAuthConfig>,
}
552
553impl AuthConfig {
554    /// Create an enabled auth config with the given API keys.
555    #[must_use]
556    pub fn with_keys(keys: Vec<ApiKeyEntry>) -> Self {
557        Self {
558            enabled: true,
559            api_keys: keys,
560            mtls: None,
561            rate_limit: None,
562            #[cfg(feature = "oauth")]
563            oauth: None,
564        }
565    }
566
567    /// Set rate limiting on this auth config.
568    #[must_use]
569    pub fn with_rate_limit(mut self, rate_limit: RateLimitConfig) -> Self {
570        self.rate_limit = Some(rate_limit);
571        self
572    }
573}
574
/// Summary of a single API key suitable for admin endpoints.
///
/// Intentionally omits the Argon2id hash - only metadata is exposed.
/// Each field is a clone of the same-named field on [`ApiKeyEntry`].
#[derive(Debug, Clone, serde::Serialize)]
#[non_exhaustive]
pub struct ApiKeySummary {
    /// Human-readable key label.
    pub name: String,
    /// RBAC role granted when this key authenticates.
    pub role: String,
    /// Optional RFC 3339 expiry timestamp.
    pub expires_at: Option<String>,
}
588
/// Snapshot of the enabled authentication methods for admin endpoints.
///
/// Produced by [`AuthConfig::summary`]; contains no key hashes.
#[derive(Debug, Clone, serde::Serialize)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "this is a flat summary of independent auth-method booleans"
)]
#[non_exhaustive]
pub struct AuthConfigSummary {
    /// Master enabled flag from config.
    pub enabled: bool,
    /// Whether API-key bearer auth is configured
    /// (true when at least one API key exists).
    pub bearer: bool,
    /// Whether mTLS client auth is configured.
    pub mtls: bool,
    /// Whether OAuth JWT validation is configured.
    /// Always `false` when the crate is built without the `oauth` feature.
    pub oauth: bool,
    /// Current API-key list (no hashes).
    pub api_keys: Vec<ApiKeySummary>,
}
608
609impl AuthConfig {
610    /// Produce a hash-free summary of the auth config for admin endpoints.
611    #[must_use]
612    pub fn summary(&self) -> AuthConfigSummary {
613        AuthConfigSummary {
614            enabled: self.enabled,
615            bearer: !self.api_keys.is_empty(),
616            mtls: self.mtls.is_some(),
617            #[cfg(feature = "oauth")]
618            oauth: self.oauth.is_some(),
619            #[cfg(not(feature = "oauth"))]
620            oauth: false,
621            api_keys: self
622                .api_keys
623                .iter()
624                .map(|k| ApiKeySummary {
625                    name: k.name.clone(),
626                    role: k.role.clone(),
627                    expires_at: k.expires_at.clone(),
628                })
629                .collect(),
630        }
631    }
632}
633
/// Keyed rate limiter type (per source IP). Memory-bounded by
/// [`RateLimitConfig::max_tracked_keys`] to defend against IP-spray `DoS`.
///
/// Both the post-failure limiter and the pre-auth gate use this alias.
pub(crate) type KeyedLimiter = BoundedKeyedLimiter<IpAddr>;
637
/// Connection info for TLS connections, carrying the peer socket address
/// and (when mTLS is configured) the verified client identity extracted
/// from the peer certificate during the TLS handshake.
///
/// Defined as a local type so we can implement axum's `Connected` trait
/// for our custom `TlsListener` without orphan rule issues. The `identity`
/// field travels with the connection itself (via the wrapping IO type),
/// so there is no shared map to race against, no port-reuse aliasing, and
/// no eviction policy to maintain.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub(crate) struct TlsConnInfo {
    /// Remote peer socket address.
    pub addr: SocketAddr,
    /// Verified mTLS client identity, if a client certificate was presented
    /// and successfully extracted during the TLS handshake.
    /// `None` otherwise.
    pub identity: Option<AuthIdentity>,
}
656
657impl TlsConnInfo {
658    /// Construct a new [`TlsConnInfo`].
659    #[must_use]
660    pub(crate) const fn new(addr: SocketAddr, identity: Option<AuthIdentity>) -> Self {
661        Self { addr, identity }
662    }
663}
664
/// Shared state for the auth middleware.
///
/// `api_keys` uses [`ArcSwap`] so the SIGHUP handler can atomically
/// swap in a new key list without blocking in-flight requests.
#[allow(
    missing_debug_implementations,
    reason = "contains governor RateLimiter and JwksCache without Debug impls"
)]
#[non_exhaustive]
pub(crate) struct AuthState {
    /// Active set of API keys (hot-swappable).
    pub api_keys: ArcSwap<Vec<ApiKeyEntry>>,
    /// Optional per-IP post-failure rate limiter (consulted *after* auth fails).
    pub rate_limiter: Option<Arc<KeyedLimiter>>,
    /// Optional per-IP pre-auth abuse gate (consulted *before* password-hash work).
    /// mTLS-authenticated connections bypass this gate.
    pub pre_auth_limiter: Option<Arc<KeyedLimiter>>,
    #[cfg(feature = "oauth")]
    /// Optional JWKS cache for OAuth JWT validation.
    pub jwks_cache: Option<Arc<crate::oauth::JwksCache>>,
    /// Tracks identity names that have already been logged at INFO level.
    /// Subsequent auths for the same identity are logged at DEBUG.
    /// Guarded by a `std::sync::Mutex`; the set is keyed on the identity
    /// `name` only, not on role or auth method.
    pub seen_identities: Mutex<HashSet<String>>,
    /// Lightweight in-memory auth success/failure counters for diagnostics.
    pub counters: AuthCounters,
}
691
692impl AuthState {
693    /// Atomically replace the API key list (lock-free, wait-free).
694    ///
695    /// New requests immediately see the updated keys.
696    /// In-flight requests that already loaded the old list finish
697    /// using it -- no torn reads.
698    pub(crate) fn reload_keys(&self, keys: Vec<ApiKeyEntry>) {
699        let count = keys.len();
700        self.api_keys.store(Arc::new(keys));
701        tracing::info!(keys = count, "API keys reloaded");
702    }
703
704    /// Snapshot auth counters for diagnostics and tests.
705    #[must_use]
706    pub(crate) fn counters_snapshot(&self) -> AuthCountersSnapshot {
707        self.counters.snapshot()
708    }
709
710    /// Produce the admin-endpoint list of API keys (metadata only, no hashes).
711    #[must_use]
712    pub(crate) fn api_key_summaries(&self) -> Vec<ApiKeySummary> {
713        self.api_keys
714            .load()
715            .iter()
716            .map(|k| ApiKeySummary {
717                name: k.name.clone(),
718                role: k.role.clone(),
719                expires_at: k.expires_at.clone(),
720            })
721            .collect()
722    }
723
724    /// Log auth success: INFO on first occurrence per identity, DEBUG after.
725    fn log_auth(&self, id: &AuthIdentity, method: &str) {
726        self.counters.record_success(id.method);
727        let first = self
728            .seen_identities
729            .lock()
730            .unwrap_or_else(std::sync::PoisonError::into_inner)
731            .insert(id.name.clone());
732        if first {
733            tracing::info!(name = %id.name, role = %id.role, "{method} authenticated");
734        } else {
735            tracing::debug!(name = %id.name, role = %id.role, "{method} authenticated");
736        }
737    }
738}
739
740/// Default auth rate limit: 30 attempts per minute per source IP.
741// SAFETY: unwrap() is safe - literal 30 is provably non-zero (const-evaluated).
742const DEFAULT_AUTH_RATE: NonZeroU32 = NonZeroU32::new(30).unwrap();
743
744/// Create a post-failure rate limiter from config.
745#[must_use]
746pub(crate) fn build_rate_limiter(config: &RateLimitConfig) -> Arc<KeyedLimiter> {
747    let quota = governor::Quota::per_minute(
748        NonZeroU32::new(config.max_attempts_per_minute).unwrap_or(DEFAULT_AUTH_RATE),
749    );
750    Arc::new(BoundedKeyedLimiter::new(
751        quota,
752        config.max_tracked_keys,
753        config.idle_eviction,
754    ))
755}
756
757/// Create a pre-auth abuse-gate rate limiter from config.
758///
759/// Quota: `pre_auth_max_per_minute` if set, otherwise
760/// `max_attempts_per_minute * 10` (capped at `u32::MAX`). The 10× factor
761/// keeps the gate generous enough for honest retries while still bounding
762/// attacker CPU on Argon2 verification.
763#[must_use]
764pub(crate) fn build_pre_auth_limiter(config: &RateLimitConfig) -> Arc<KeyedLimiter> {
765    let resolved = config.pre_auth_max_per_minute.unwrap_or_else(|| {
766        config
767            .max_attempts_per_minute
768            .saturating_mul(PRE_AUTH_DEFAULT_MULTIPLIER)
769    });
770    let quota =
771        governor::Quota::per_minute(NonZeroU32::new(resolved).unwrap_or(DEFAULT_PRE_AUTH_RATE));
772    Arc::new(BoundedKeyedLimiter::new(
773        quota,
774        config.max_tracked_keys,
775        config.idle_eviction,
776    ))
777}
778
779/// Default multiplier applied to `max_attempts_per_minute` when the operator
780/// does not set `pre_auth_max_per_minute` explicitly.
781const PRE_AUTH_DEFAULT_MULTIPLIER: u32 = 10;
782
783/// Default pre-auth abuse-gate rate (used only if both the configured value
784/// and the multiplied fallback are zero, which `NonZeroU32::new` rejects).
785// SAFETY: unwrap() is safe - literal 300 is provably non-zero (const-evaluated).
786const DEFAULT_PRE_AUTH_RATE: NonZeroU32 = NonZeroU32::new(300).unwrap();
787
788/// Parse an mTLS client certificate and extract an `AuthIdentity`.
789///
790/// Reads the Subject CN as the identity name. Falls back to the first
791/// DNS SAN if CN is absent. The role is taken from the `MtlsConfig`.
792#[must_use]
793pub fn extract_mtls_identity(cert_der: &[u8], default_role: &str) -> Option<AuthIdentity> {
794    let (_, cert) = X509Certificate::from_der(cert_der).ok()?;
795
796    // Try CN from Subject first.
797    let cn = cert
798        .subject()
799        .iter_common_name()
800        .next()
801        .and_then(|attr| attr.as_str().ok())
802        .map(String::from);
803
804    // Fall back to first DNS SAN.
805    let name = cn.or_else(|| {
806        cert.subject_alternative_name()
807            .ok()
808            .flatten()
809            .and_then(|san| {
810                #[allow(clippy::wildcard_enum_match_arm)]
811                san.value.general_names.iter().find_map(|gn| match gn {
812                    GeneralName::DNSName(dns) => Some((*dns).to_owned()),
813                    _ => None,
814                })
815            })
816    })?;
817
818    // Reject identities with characters unsafe for logging and RBAC matching.
819    if !name
820        .chars()
821        .all(|c| c.is_alphanumeric() || matches!(c, '-' | '.' | '_' | '@'))
822    {
823        tracing::warn!(cn = %name, "mTLS identity rejected: invalid characters in CN/SAN");
824        return None;
825    }
826
827    Some(AuthIdentity {
828        name,
829        role: default_role.to_owned(),
830        method: AuthMethod::MtlsCertificate,
831        raw_token: None,
832        sub: None,
833    })
834}
835
/// Pull the credential out of an `Authorization: Bearer …` header value.
///
/// Per RFC 7235 §2.1 the auth-scheme is matched **case-insensitively**, so
/// `Bearer`, `bearer`, `BEARER`, and `BeArEr` are all accepted. The scheme is
/// everything before the first SP; any further SP characters preceding the
/// token are skipped.
///
/// Returns `None` when the value:
/// - contains no space at all (no scheme/credentials boundary),
/// - names a scheme other than `Bearer`, or
/// - has nothing but spaces after the scheme.
///
/// Only the scheme prefix is handled here; validating the token itself
/// (length, charset, signature, …) is the caller's job.
fn extract_bearer(value: &str) -> Option<&str> {
    let (scheme, credentials) = value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    Some(credentials.trim_start_matches(' ')).filter(|token| !token.is_empty())
}
859
860/// Verify a bearer token against configured API keys.
861///
862/// Argon2id verification is CPU-intensive, so this should be called via
863/// `spawn_blocking`. Returns the matching identity if the token is valid.
864///
865/// Iterates **all** keys to completion to prevent timing side-channels
866/// that would reveal how many keys exist or which slot matched.
867#[must_use]
868pub fn verify_bearer_token(token: &str, keys: &[ApiKeyEntry]) -> Option<AuthIdentity> {
869    let now = chrono::Utc::now();
870
871    // Always iterate ALL keys to completion to prevent timing side-channels
872    // that reveal how many keys exist or which position matched.
873    let mut result: Option<AuthIdentity> = None;
874
875    for key in keys {
876        // Check expiry
877        if let Some(ref expires) = key.expires_at
878            && let Ok(exp) = chrono::DateTime::parse_from_rfc3339(expires)
879            && exp < now
880        {
881            continue;
882        }
883
884        // Argon2id verification (constant-time internally).
885        // Keep the first match but continue checking remaining keys.
886        if result.is_none()
887            && let Ok(parsed_hash) = PasswordHash::new(&key.hash)
888            && Argon2::default()
889                .verify_password(token.as_bytes(), &parsed_hash)
890                .is_ok()
891        {
892            result = Some(AuthIdentity {
893                name: key.name.clone(),
894                role: key.role.clone(),
895                method: AuthMethod::BearerToken,
896                raw_token: None,
897                sub: None,
898            });
899        }
900    }
901    result
902}
903
904/// Generate a new API key: 256-bit random token + Argon2id hash.
905///
906/// Returns `(plaintext_token, argon2id_hash_phc_string)`.
907/// The plaintext is shown once to the user and never stored.
908///
909/// # Errors
910///
911/// Returns an error if salt encoding or Argon2id hashing fails
912/// (should not happen with valid inputs, but we avoid panicking).
913pub fn generate_api_key() -> Result<(String, String), McpxError> {
914    let mut token_bytes = [0u8; 32];
915    rand::fill(&mut token_bytes);
916    let token = URL_SAFE_NO_PAD.encode(token_bytes);
917
918    // Generate 16 random bytes for salt, encode as base64 for SaltString.
919    let mut salt_bytes = [0u8; 16];
920    rand::fill(&mut salt_bytes);
921    let salt = SaltString::encode_b64(&salt_bytes)
922        .map_err(|e| McpxError::Auth(format!("salt encoding failed: {e}")))?;
923    let hash = Argon2::default()
924        .hash_password(token.as_bytes(), &salt)
925        .map_err(|e| McpxError::Auth(format!("argon2id hashing failed: {e}")))?
926        .to_string();
927
928    Ok((token, hash))
929}
930
931fn build_www_authenticate_value(
932    advertise_resource_metadata: bool,
933    failure: AuthFailureClass,
934) -> String {
935    let (error, error_description) = failure.bearer_error();
936    if advertise_resource_metadata {
937        return format!(
938            "Bearer resource_metadata=\"/.well-known/oauth-protected-resource\", error=\"{error}\", error_description=\"{error_description}\""
939        );
940    }
941    format!("Bearer error=\"{error}\", error_description=\"{error_description}\"")
942}
943
944fn auth_method_label(method: AuthMethod) -> &'static str {
945    match method {
946        AuthMethod::MtlsCertificate => "mTLS",
947        AuthMethod::BearerToken => "bearer token",
948        AuthMethod::OAuthJwt => "OAuth JWT",
949    }
950}
951
952#[cfg_attr(not(feature = "oauth"), allow(unused_variables))]
953fn unauthorized_response(state: &AuthState, failure_class: AuthFailureClass) -> Response {
954    #[cfg(feature = "oauth")]
955    let advertise_resource_metadata = state.jwks_cache.is_some();
956    #[cfg(not(feature = "oauth"))]
957    let advertise_resource_metadata = false;
958
959    let challenge = build_www_authenticate_value(advertise_resource_metadata, failure_class);
960    (
961        axum::http::StatusCode::UNAUTHORIZED,
962        [(header::WWW_AUTHENTICATE, challenge)],
963        failure_class.response_body(),
964    )
965        .into_response()
966}
967
968async fn authenticate_bearer_identity(
969    state: &AuthState,
970    token: &str,
971) -> Result<AuthIdentity, AuthFailureClass> {
972    let mut failure_class = AuthFailureClass::MissingCredential;
973
974    #[cfg(feature = "oauth")]
975    if let Some(ref cache) = state.jwks_cache
976        && crate::oauth::looks_like_jwt(token)
977    {
978        match cache.validate_token_with_reason(token).await {
979            Ok(mut id) => {
980                id.raw_token = Some(SecretString::from(token.to_owned()));
981                return Ok(id);
982            }
983            Err(crate::oauth::JwtValidationFailure::Expired) => {
984                failure_class = AuthFailureClass::ExpiredCredential;
985            }
986            Err(crate::oauth::JwtValidationFailure::Invalid) => {
987                failure_class = AuthFailureClass::InvalidCredential;
988            }
989        }
990    }
991
992    let token = token.to_owned();
993    let keys = state.api_keys.load_full(); // Arc clone, lock-free
994
995    // Argon2id is CPU-bound - offload to blocking thread pool.
996    let identity = tokio::task::spawn_blocking(move || verify_bearer_token(&token, &keys))
997        .await
998        .ok()
999        .flatten();
1000
1001    if let Some(id) = identity {
1002        return Ok(id);
1003    }
1004
1005    if failure_class == AuthFailureClass::MissingCredential {
1006        failure_class = AuthFailureClass::InvalidCredential;
1007    }
1008
1009    Err(failure_class)
1010}
1011
1012/// Consult the pre-auth abuse gate for the given peer.
1013///
1014/// Returns `Some(response)` if the request should be rejected (limiter
1015/// configured AND quota exhausted for this source IP). Returns `None`
1016/// otherwise (limiter absent, peer address unknown, or quota available),
1017/// in which case the caller should proceed with credential verification.
1018///
1019/// Side effects on rejection: increments the `pre_auth_gate` failure
1020/// counter and emits a warn-level log. mTLS-authenticated requests must
1021/// be admitted by the caller *before* invoking this helper.
1022fn pre_auth_gate(state: &AuthState, peer_addr: Option<SocketAddr>) -> Option<Response> {
1023    let limiter = state.pre_auth_limiter.as_ref()?;
1024    let addr = peer_addr?;
1025    if limiter.check_key(&addr.ip()).is_ok() {
1026        return None;
1027    }
1028    state.counters.record_failure(AuthFailureClass::PreAuthGate);
1029    tracing::warn!(
1030        ip = %addr.ip(),
1031        "auth rate limited by pre-auth gate (request rejected before credential verification)"
1032    );
1033    Some(
1034        McpxError::RateLimited("too many unauthenticated requests from this source".into())
1035            .into_response(),
1036    )
1037}
1038
1039/// Axum middleware that enforces authentication.
1040///
1041/// Tries authentication methods in priority order:
1042/// 1. mTLS client certificate identity (populated by TLS acceptor)
1043/// 2. Bearer token from `Authorization` header
1044///
1045/// Failed authentication attempts are rate-limited per source IP.
1046/// Successful authentications do not consume rate limit budget.
1047pub(crate) async fn auth_middleware(
1048    state: Arc<AuthState>,
1049    req: Request<Body>,
1050    next: Next,
1051) -> Response {
1052    // Extract peer address (and any mTLS identity) from ConnectInfo.
1053    // Plain TCP: ConnectInfo<SocketAddr>. TLS / mTLS: ConnectInfo<TlsConnInfo>,
1054    // which carries the verified identity directly on the connection — no
1055    // shared map, no port-reuse aliasing.
1056    let tls_info = req.extensions().get::<ConnectInfo<TlsConnInfo>>().cloned();
1057    let peer_addr = req
1058        .extensions()
1059        .get::<ConnectInfo<SocketAddr>>()
1060        .map(|ci| ci.0)
1061        .or_else(|| tls_info.as_ref().map(|ci| ci.0.addr));
1062
1063    // 1. Try mTLS identity (extracted by the TLS acceptor during handshake
1064    //    and attached to the connection itself).
1065    //
1066    //    mTLS connections bypass the pre-auth abuse gate below: the TLS
1067    //    handshake already performed expensive crypto with a verified peer,
1068    //    so we trust them not to be a CPU-spray attacker.
1069    if let Some(id) = tls_info.and_then(|ci| ci.0.identity) {
1070        state.log_auth(&id, "mTLS");
1071        let mut req = req;
1072        req.extensions_mut().insert(id);
1073        return next.run(req).await;
1074    }
1075
1076    // 2. Pre-auth abuse gate: rejects CPU-spray attacks BEFORE the Argon2id
1077    //    verification path runs. Keyed by source IP. mTLS connections (above)
1078    //    are exempt; this gate only protects the bearer/JWT verification path.
1079    if let Some(blocked) = pre_auth_gate(&state, peer_addr) {
1080        return blocked;
1081    }
1082
1083    let failure_class = if let Some(value) = req.headers().get(header::AUTHORIZATION) {
1084        match value.to_str().ok().and_then(extract_bearer) {
1085            Some(token) => match authenticate_bearer_identity(&state, token).await {
1086                Ok(id) => {
1087                    state.log_auth(&id, auth_method_label(id.method));
1088                    let mut req = req;
1089                    req.extensions_mut().insert(id);
1090                    return next.run(req).await;
1091                }
1092                Err(class) => class,
1093            },
1094            None => AuthFailureClass::InvalidCredential,
1095        }
1096    } else {
1097        AuthFailureClass::MissingCredential
1098    };
1099
1100    tracing::warn!(failure_class = %failure_class.as_str(), "auth failed");
1101
1102    // Rate limit check (applied after auth failure only).
1103    // Successful authentications do not consume rate limit budget.
1104    if let (Some(limiter), Some(addr)) = (&state.rate_limiter, peer_addr)
1105        && limiter.check_key(&addr.ip()).is_err()
1106    {
1107        state.counters.record_failure(AuthFailureClass::RateLimited);
1108        tracing::warn!(ip = %addr.ip(), "auth rate limited after repeated failures");
1109        return McpxError::RateLimited("too many failed authentication attempts".into())
1110            .into_response();
1111    }
1112
1113    state.counters.record_failure(failure_class);
1114    unauthorized_response(&state, failure_class)
1115}
1116
1117#[cfg(test)]
1118mod tests {
1119    use super::*;
1120
1121    #[test]
1122    fn generate_and_verify_api_key() {
1123        let (token, hash) = generate_api_key().unwrap();
1124
1125        // Token is 43 chars (256-bit base64url, no padding)
1126        assert_eq!(token.len(), 43);
1127
1128        // Hash is a valid PHC string
1129        assert!(hash.starts_with("$argon2id$"));
1130
1131        // Verification succeeds with correct token
1132        let keys = vec![ApiKeyEntry {
1133            name: "test".into(),
1134            hash,
1135            role: "viewer".into(),
1136            expires_at: None,
1137        }];
1138        let id = verify_bearer_token(&token, &keys);
1139        assert!(id.is_some());
1140        let id = id.unwrap();
1141        assert_eq!(id.name, "test");
1142        assert_eq!(id.role, "viewer");
1143        assert_eq!(id.method, AuthMethod::BearerToken);
1144    }
1145
1146    #[test]
1147    fn wrong_token_rejected() {
1148        let (_token, hash) = generate_api_key().unwrap();
1149        let keys = vec![ApiKeyEntry {
1150            name: "test".into(),
1151            hash,
1152            role: "viewer".into(),
1153            expires_at: None,
1154        }];
1155        assert!(verify_bearer_token("wrong-token", &keys).is_none());
1156    }
1157
1158    #[test]
1159    fn expired_key_rejected() {
1160        let (token, hash) = generate_api_key().unwrap();
1161        let keys = vec![ApiKeyEntry {
1162            name: "test".into(),
1163            hash,
1164            role: "viewer".into(),
1165            expires_at: Some("2020-01-01T00:00:00Z".into()),
1166        }];
1167        assert!(verify_bearer_token(&token, &keys).is_none());
1168    }
1169
1170    #[test]
1171    fn future_expiry_accepted() {
1172        let (token, hash) = generate_api_key().unwrap();
1173        let keys = vec![ApiKeyEntry {
1174            name: "test".into(),
1175            hash,
1176            role: "viewer".into(),
1177            expires_at: Some("2099-01-01T00:00:00Z".into()),
1178        }];
1179        assert!(verify_bearer_token(&token, &keys).is_some());
1180    }
1181
1182    #[test]
1183    fn multiple_keys_first_match_wins() {
1184        let (token, hash) = generate_api_key().unwrap();
1185        let keys = vec![
1186            ApiKeyEntry {
1187                name: "wrong".into(),
1188                hash: "$argon2id$v=19$m=19456,t=2,p=1$invalid$invalid".into(),
1189                role: "ops".into(),
1190                expires_at: None,
1191            },
1192            ApiKeyEntry {
1193                name: "correct".into(),
1194                hash,
1195                role: "deploy".into(),
1196                expires_at: None,
1197            },
1198        ];
1199        let id = verify_bearer_token(&token, &keys).unwrap();
1200        assert_eq!(id.name, "correct");
1201        assert_eq!(id.role, "deploy");
1202    }
1203
1204    #[test]
1205    fn rate_limiter_allows_within_quota() {
1206        let config = RateLimitConfig {
1207            max_attempts_per_minute: 5,
1208            pre_auth_max_per_minute: None,
1209            ..Default::default()
1210        };
1211        let limiter = build_rate_limiter(&config);
1212        let ip: IpAddr = "10.0.0.1".parse().unwrap();
1213
1214        // First 5 should succeed.
1215        for _ in 0..5 {
1216            assert!(limiter.check_key(&ip).is_ok());
1217        }
1218        // 6th should fail.
1219        assert!(limiter.check_key(&ip).is_err());
1220    }
1221
1222    #[test]
1223    fn rate_limiter_separate_ips() {
1224        let config = RateLimitConfig {
1225            max_attempts_per_minute: 2,
1226            pre_auth_max_per_minute: None,
1227            ..Default::default()
1228        };
1229        let limiter = build_rate_limiter(&config);
1230        let ip1: IpAddr = "10.0.0.1".parse().unwrap();
1231        let ip2: IpAddr = "10.0.0.2".parse().unwrap();
1232
1233        // Exhaust ip1's quota.
1234        assert!(limiter.check_key(&ip1).is_ok());
1235        assert!(limiter.check_key(&ip1).is_ok());
1236        assert!(limiter.check_key(&ip1).is_err());
1237
1238        // ip2 should still have quota.
1239        assert!(limiter.check_key(&ip2).is_ok());
1240    }
1241
1242    #[test]
1243    fn extract_mtls_identity_from_cn() {
1244        // Generate a cert with explicit CN.
1245        let mut params = rcgen::CertificateParams::new(vec!["test-client.local".into()]).unwrap();
1246        params.distinguished_name = rcgen::DistinguishedName::new();
1247        params
1248            .distinguished_name
1249            .push(rcgen::DnType::CommonName, "test-client");
1250        let cert = params
1251            .self_signed(&rcgen::KeyPair::generate().unwrap())
1252            .unwrap();
1253        let der = cert.der();
1254
1255        let id = extract_mtls_identity(der, "ops").unwrap();
1256        assert_eq!(id.name, "test-client");
1257        assert_eq!(id.role, "ops");
1258        assert_eq!(id.method, AuthMethod::MtlsCertificate);
1259    }
1260
1261    #[test]
1262    fn extract_mtls_identity_falls_back_to_san() {
1263        // Cert with no CN but has a DNS SAN.
1264        let mut params =
1265            rcgen::CertificateParams::new(vec!["san-only.example.com".into()]).unwrap();
1266        params.distinguished_name = rcgen::DistinguishedName::new();
1267        // No CN set - should fall back to DNS SAN.
1268        let cert = params
1269            .self_signed(&rcgen::KeyPair::generate().unwrap())
1270            .unwrap();
1271        let der = cert.der();
1272
1273        let id = extract_mtls_identity(der, "viewer").unwrap();
1274        assert_eq!(id.name, "san-only.example.com");
1275        assert_eq!(id.role, "viewer");
1276    }
1277
1278    #[test]
1279    fn extract_mtls_identity_invalid_der() {
1280        assert!(extract_mtls_identity(b"not-a-cert", "viewer").is_none());
1281    }
1282
1283    // -- auth_middleware integration tests --
1284
1285    use axum::{
1286        body::Body,
1287        http::{Request, StatusCode},
1288    };
1289    use tower::ServiceExt as _;
1290
1291    fn auth_router(state: Arc<AuthState>) -> axum::Router {
1292        axum::Router::new()
1293            .route("/mcp", axum::routing::post(|| async { "ok" }))
1294            .layer(axum::middleware::from_fn(move |req, next| {
1295                let s = Arc::clone(&state);
1296                auth_middleware(s, req, next)
1297            }))
1298    }
1299
1300    fn test_auth_state(keys: Vec<ApiKeyEntry>) -> Arc<AuthState> {
1301        Arc::new(AuthState {
1302            api_keys: ArcSwap::new(Arc::new(keys)),
1303            rate_limiter: None,
1304            pre_auth_limiter: None,
1305            #[cfg(feature = "oauth")]
1306            jwks_cache: None,
1307            seen_identities: Mutex::new(HashSet::new()),
1308            counters: AuthCounters::default(),
1309        })
1310    }
1311
1312    #[tokio::test]
1313    async fn middleware_rejects_no_credentials() {
1314        let state = test_auth_state(vec![]);
1315        let app = auth_router(Arc::clone(&state));
1316        let req = Request::builder()
1317            .method(axum::http::Method::POST)
1318            .uri("/mcp")
1319            .body(Body::empty())
1320            .unwrap();
1321        let resp = app.oneshot(req).await.unwrap();
1322        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
1323        let challenge = resp
1324            .headers()
1325            .get(header::WWW_AUTHENTICATE)
1326            .unwrap()
1327            .to_str()
1328            .unwrap();
1329        assert!(challenge.contains("error=\"invalid_request\""));
1330
1331        let counters = state.counters_snapshot();
1332        assert_eq!(counters.failure_missing_credential, 1);
1333    }
1334
1335    #[tokio::test]
1336    async fn middleware_accepts_valid_bearer() {
1337        let (token, hash) = generate_api_key().unwrap();
1338        let keys = vec![ApiKeyEntry {
1339            name: "test-key".into(),
1340            hash,
1341            role: "ops".into(),
1342            expires_at: None,
1343        }];
1344        let state = test_auth_state(keys);
1345        let app = auth_router(Arc::clone(&state));
1346        let req = Request::builder()
1347            .method(axum::http::Method::POST)
1348            .uri("/mcp")
1349            .header("authorization", format!("Bearer {token}"))
1350            .body(Body::empty())
1351            .unwrap();
1352        let resp = app.oneshot(req).await.unwrap();
1353        assert_eq!(resp.status(), StatusCode::OK);
1354
1355        let counters = state.counters_snapshot();
1356        assert_eq!(counters.success_bearer, 1);
1357    }
1358
1359    #[tokio::test]
1360    async fn middleware_rejects_wrong_bearer() {
1361        let (_token, hash) = generate_api_key().unwrap();
1362        let keys = vec![ApiKeyEntry {
1363            name: "test-key".into(),
1364            hash,
1365            role: "ops".into(),
1366            expires_at: None,
1367        }];
1368        let state = test_auth_state(keys);
1369        let app = auth_router(Arc::clone(&state));
1370        let req = Request::builder()
1371            .method(axum::http::Method::POST)
1372            .uri("/mcp")
1373            .header("authorization", "Bearer wrong-token-here")
1374            .body(Body::empty())
1375            .unwrap();
1376        let resp = app.oneshot(req).await.unwrap();
1377        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
1378        let challenge = resp
1379            .headers()
1380            .get(header::WWW_AUTHENTICATE)
1381            .unwrap()
1382            .to_str()
1383            .unwrap();
1384        assert!(challenge.contains("error=\"invalid_token\""));
1385
1386        let counters = state.counters_snapshot();
1387        assert_eq!(counters.failure_invalid_credential, 1);
1388    }
1389
1390    #[tokio::test]
1391    async fn middleware_rate_limits() {
1392        let state = Arc::new(AuthState {
1393            api_keys: ArcSwap::new(Arc::new(vec![])),
1394            rate_limiter: Some(build_rate_limiter(&RateLimitConfig {
1395                max_attempts_per_minute: 1,
1396                pre_auth_max_per_minute: None,
1397                ..Default::default()
1398            })),
1399            pre_auth_limiter: None,
1400            #[cfg(feature = "oauth")]
1401            jwks_cache: None,
1402            seen_identities: Mutex::new(HashSet::new()),
1403            counters: AuthCounters::default(),
1404        });
1405        let app = auth_router(state);
1406
1407        // First request: UNAUTHORIZED (no credentials, but not rate limited)
1408        let req = Request::builder()
1409            .method(axum::http::Method::POST)
1410            .uri("/mcp")
1411            .body(Body::empty())
1412            .unwrap();
1413        let resp = app.clone().oneshot(req).await.unwrap();
1414        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
1415
1416        // Second request from same "IP" (no ConnectInfo in test, so peer_addr is None
1417        // and rate limiter won't fire). That's expected -- rate limiting requires
1418        // ConnectInfo which isn't available in unit tests without a real server.
1419        // This test verifies the middleware wiring doesn't panic.
1420    }
1421
1422    /// Verify that rate limit semantics: only failed auth attempts consume budget.
1423    ///
1424    /// This is a unit test of the limiter behavior. The middleware integration
1425    /// is that on auth failure, `check_key` is called; on auth success, it is NOT.
1426    /// Full e2e tests verify the middleware routing but require `ConnectInfo`.
1427    #[test]
1428    fn rate_limit_semantics_failed_only() {
1429        let config = RateLimitConfig {
1430            max_attempts_per_minute: 3,
1431            pre_auth_max_per_minute: None,
1432            ..Default::default()
1433        };
1434        let limiter = build_rate_limiter(&config);
1435        let ip: IpAddr = "192.168.1.100".parse().unwrap();
1436
1437        // Simulate: 3 failed attempts should exhaust quota.
1438        assert!(
1439            limiter.check_key(&ip).is_ok(),
1440            "failure 1 should be allowed"
1441        );
1442        assert!(
1443            limiter.check_key(&ip).is_ok(),
1444            "failure 2 should be allowed"
1445        );
1446        assert!(
1447            limiter.check_key(&ip).is_ok(),
1448            "failure 3 should be allowed"
1449        );
1450        assert!(
1451            limiter.check_key(&ip).is_err(),
1452            "failure 4 should be blocked"
1453        );
1454
1455        // In the actual middleware flow:
1456        // - Successful auth: verify_bearer_token returns Some, we return early
1457        //   WITHOUT calling check_key, so no budget consumed.
1458        // - Failed auth: verify_bearer_token returns None, we call check_key
1459        //   THEN return 401, so budget is consumed.
1460        //
1461        // This means N successful requests followed by M failed requests
1462        // will only count M toward the rate limit, not N+M.
1463    }
1464
1465    // -- pre-auth abuse gate (H-S1) --
1466
1467    /// The pre-auth gate must default to ~10x the post-failure quota so honest
1468    /// retry storms never trip it but a Argon2-spray attacker is throttled.
1469    #[test]
1470    fn pre_auth_default_multiplier_is_10x() {
1471        let config = RateLimitConfig {
1472            max_attempts_per_minute: 5,
1473            pre_auth_max_per_minute: None,
1474            ..Default::default()
1475        };
1476        let limiter = build_pre_auth_limiter(&config);
1477        let ip: IpAddr = "10.0.0.1".parse().unwrap();
1478
1479        // Quota should be 50 (5 * 10), not 5. We expect the first 50 to pass.
1480        for i in 0..50 {
1481            assert!(
1482                limiter.check_key(&ip).is_ok(),
1483                "pre-auth attempt {i} (of expected 50) should be allowed under default 10x multiplier"
1484            );
1485        }
1486        // The 51st attempt must be blocked: confirms quota is bounded, not infinite.
1487        assert!(
1488            limiter.check_key(&ip).is_err(),
1489            "pre-auth attempt 51 should be blocked (quota is 50, not unbounded)"
1490        );
1491    }
1492
1493    /// An explicit `pre_auth_max_per_minute` override must win over the
1494    /// 10x-multiplier default.
1495    #[test]
1496    fn pre_auth_explicit_override_wins() {
1497        let config = RateLimitConfig {
1498            max_attempts_per_minute: 100,     // would default to 1000 pre-auth quota
1499            pre_auth_max_per_minute: Some(2), // but operator caps at 2
1500            ..Default::default()
1501        };
1502        let limiter = build_pre_auth_limiter(&config);
1503        let ip: IpAddr = "10.0.0.2".parse().unwrap();
1504
1505        assert!(limiter.check_key(&ip).is_ok(), "attempt 1 allowed");
1506        assert!(limiter.check_key(&ip).is_ok(), "attempt 2 allowed");
1507        assert!(
1508            limiter.check_key(&ip).is_err(),
1509            "attempt 3 must be blocked (explicit override of 2 wins over 10x default of 1000)"
1510        );
1511    }
1512
1513    /// End-to-end: the pre-auth gate must reject before the bearer-verification
1514    /// path runs. We exhaust the gate's quota (Some(1)) with one bad-bearer
1515    /// request, then the second request must be rejected with 429 + the
1516    /// `pre_auth_gate` failure counter incremented (NOT
1517    /// `failure_invalid_credential`, which would prove Argon2 ran).
1518    #[tokio::test]
1519    async fn pre_auth_gate_blocks_before_argon2_verification() {
1520        let (_token, hash) = generate_api_key().unwrap();
1521        let keys = vec![ApiKeyEntry {
1522            name: "test-key".into(),
1523            hash,
1524            role: "ops".into(),
1525            expires_at: None,
1526        }];
1527        let config = RateLimitConfig {
1528            max_attempts_per_minute: 100,
1529            pre_auth_max_per_minute: Some(1),
1530            ..Default::default()
1531        };
1532        let state = Arc::new(AuthState {
1533            api_keys: ArcSwap::new(Arc::new(keys)),
1534            rate_limiter: None,
1535            pre_auth_limiter: Some(build_pre_auth_limiter(&config)),
1536            #[cfg(feature = "oauth")]
1537            jwks_cache: None,
1538            seen_identities: Mutex::new(HashSet::new()),
1539            counters: AuthCounters::default(),
1540        });
1541        let app = auth_router(Arc::clone(&state));
1542        let peer: SocketAddr = "10.0.0.10:54321".parse().unwrap();
1543
1544        // First bad-bearer request: gate has quota, bearer verification runs,
1545        // returns 401 (invalid credential).
1546        let mut req1 = Request::builder()
1547            .method(axum::http::Method::POST)
1548            .uri("/mcp")
1549            .header("authorization", "Bearer obviously-not-a-real-token")
1550            .body(Body::empty())
1551            .unwrap();
1552        req1.extensions_mut().insert(ConnectInfo(peer));
1553        let resp1 = app.clone().oneshot(req1).await.unwrap();
1554        assert_eq!(
1555            resp1.status(),
1556            StatusCode::UNAUTHORIZED,
1557            "first attempt: gate has quota, falls through to bearer auth which fails with 401"
1558        );
1559
1560        // Second bad-bearer request from same IP: gate quota exhausted, must
1561        // reject with 429 BEFORE the Argon2 verification path runs.
1562        let mut req2 = Request::builder()
1563            .method(axum::http::Method::POST)
1564            .uri("/mcp")
1565            .header("authorization", "Bearer also-not-a-real-token")
1566            .body(Body::empty())
1567            .unwrap();
1568        req2.extensions_mut().insert(ConnectInfo(peer));
1569        let resp2 = app.oneshot(req2).await.unwrap();
1570        assert_eq!(
1571            resp2.status(),
1572            StatusCode::TOO_MANY_REQUESTS,
1573            "second attempt from same IP: pre-auth gate must reject with 429"
1574        );
1575
1576        let counters = state.counters_snapshot();
1577        assert_eq!(
1578            counters.failure_pre_auth_gate, 1,
1579            "exactly one request must have been rejected by the pre-auth gate"
1580        );
1581        // Critical: Argon2 verification must NOT have run on the gated request.
1582        // The first request's 401 increments `failure_invalid_credential` to 1;
1583        // the second (gated) request must NOT increment it further.
1584        assert_eq!(
1585            counters.failure_invalid_credential, 1,
1586            "bearer verification must run exactly once (only the un-gated first request)"
1587        );
1588    }
1589
1590    /// mTLS-authenticated requests must bypass the pre-auth gate entirely.
1591    /// The TLS handshake already performed expensive crypto with a verified
1592    /// peer, so mTLS callers should never be throttled by this gate.
1593    ///
1594    /// Setup: a pre-auth gate with quota 1 (very tight). Submit two mTLS
1595    /// requests in quick succession from the same IP. Both must succeed.
1596    #[tokio::test]
1597    async fn pre_auth_gate_does_not_throttle_mtls() {
1598        let config = RateLimitConfig {
1599            max_attempts_per_minute: 100,
1600            pre_auth_max_per_minute: Some(1), // tight: would block 2nd plain request
1601            ..Default::default()
1602        };
1603        let state = Arc::new(AuthState {
1604            api_keys: ArcSwap::new(Arc::new(vec![])),
1605            rate_limiter: None,
1606            pre_auth_limiter: Some(build_pre_auth_limiter(&config)),
1607            #[cfg(feature = "oauth")]
1608            jwks_cache: None,
1609            seen_identities: Mutex::new(HashSet::new()),
1610            counters: AuthCounters::default(),
1611        });
1612        let app = auth_router(Arc::clone(&state));
1613        let peer: SocketAddr = "10.0.0.20:54321".parse().unwrap();
1614        let identity = AuthIdentity {
1615            name: "cn=test-client".into(),
1616            role: "viewer".into(),
1617            method: AuthMethod::MtlsCertificate,
1618            raw_token: None,
1619            sub: None,
1620        };
1621        let tls_info = TlsConnInfo::new(peer, Some(identity));
1622
1623        for i in 0..3 {
1624            let mut req = Request::builder()
1625                .method(axum::http::Method::POST)
1626                .uri("/mcp")
1627                .body(Body::empty())
1628                .unwrap();
1629            req.extensions_mut().insert(ConnectInfo(tls_info.clone()));
1630            let resp = app.clone().oneshot(req).await.unwrap();
1631            assert_eq!(
1632                resp.status(),
1633                StatusCode::OK,
1634                "mTLS request {i} must succeed: pre-auth gate must not apply to mTLS callers"
1635            );
1636        }
1637
1638        let counters = state.counters_snapshot();
1639        assert_eq!(
1640            counters.failure_pre_auth_gate, 0,
1641            "pre-auth gate counter must remain at zero: mTLS bypasses the gate"
1642        );
1643        assert_eq!(
1644            counters.success_mtls, 3,
1645            "all three mTLS requests must have been counted as successful"
1646        );
1647    }
1648
1649    // -------------------------------------------------------------------
1650    // RFC 7235 §2.1 case-insensitive scheme parsing for `extract_bearer`.
1651    // -------------------------------------------------------------------
1652
1653    #[test]
1654    fn extract_bearer_accepts_canonical_case() {
1655        assert_eq!(extract_bearer("Bearer abc123"), Some("abc123"));
1656    }
1657
1658    #[test]
1659    fn extract_bearer_is_case_insensitive_per_rfc7235() {
1660        // RFC 7235 §2.1: "auth-scheme is case-insensitive".
1661        // Real-world clients (curl, browsers, custom HTTP libs) emit varied
1662        // casings; rejecting any of them is a spec violation.
1663        for header in &[
1664            "bearer abc123",
1665            "BEARER abc123",
1666            "BeArEr abc123",
1667            "bEaReR abc123",
1668        ] {
1669            assert_eq!(
1670                extract_bearer(header),
1671                Some("abc123"),
1672                "header {header:?} must parse as a Bearer token (RFC 7235 §2.1)"
1673            );
1674        }
1675    }
1676
1677    #[test]
1678    fn extract_bearer_rejects_other_schemes() {
1679        assert_eq!(extract_bearer("Basic dXNlcjpwYXNz"), None);
1680        assert_eq!(extract_bearer("Digest username=\"x\""), None);
1681        assert_eq!(extract_bearer("Token abc123"), None);
1682    }
1683
1684    #[test]
1685    fn extract_bearer_rejects_malformed() {
1686        // Empty string, no separator, scheme-only, scheme + only whitespace.
1687        assert_eq!(extract_bearer(""), None);
1688        assert_eq!(extract_bearer("Bearer"), None);
1689        assert_eq!(extract_bearer("Bearer "), None);
1690        assert_eq!(extract_bearer("Bearer    "), None);
1691    }
1692
1693    #[test]
1694    fn extract_bearer_tolerates_extra_separator_whitespace() {
1695        // Some non-conformant clients emit two spaces; we should still parse.
1696        assert_eq!(extract_bearer("Bearer  abc123"), Some("abc123"));
1697        assert_eq!(extract_bearer("Bearer   abc123"), Some("abc123"));
1698    }
1699
1700    // -------------------------------------------------------------------
1701    // Debug redaction: ensure `AuthIdentity` and `ApiKeyEntry` never leak
1702    // secret material via `format!("{:?}", …)` or `tracing::debug!(?…)`.
1703    // -------------------------------------------------------------------
1704
1705    #[test]
1706    fn auth_identity_debug_redacts_raw_token() {
1707        let id = AuthIdentity {
1708            name: "alice".into(),
1709            role: "admin".into(),
1710            method: AuthMethod::OAuthJwt,
1711            raw_token: Some(SecretString::from("super-secret-jwt-payload-xyz")),
1712            sub: Some("keycloak-uuid-2f3c8b".into()),
1713        };
1714        let dbg = format!("{id:?}");
1715
1716        // Plaintext fields must be visible (they are not secrets).
1717        assert!(dbg.contains("alice"), "name should be visible: {dbg}");
1718        assert!(dbg.contains("admin"), "role should be visible: {dbg}");
1719        assert!(dbg.contains("OAuthJwt"), "method should be visible: {dbg}");
1720
1721        // Secret fields must NOT leak.
1722        assert!(
1723            !dbg.contains("super-secret-jwt-payload-xyz"),
1724            "raw_token must be redacted in Debug output: {dbg}"
1725        );
1726        assert!(
1727            !dbg.contains("keycloak-uuid-2f3c8b"),
1728            "sub must be redacted in Debug output: {dbg}"
1729        );
1730        assert!(
1731            dbg.contains("<redacted>"),
1732            "redaction marker missing: {dbg}"
1733        );
1734    }
1735
1736    #[test]
1737    fn auth_identity_debug_marks_absent_secrets() {
1738        // For non-OAuth identities (mTLS / API key) the secret fields are
1739        // None; redacted Debug output should distinguish that from "present".
1740        let id = AuthIdentity {
1741            name: "viewer-key".into(),
1742            role: "viewer".into(),
1743            method: AuthMethod::BearerToken,
1744            raw_token: None,
1745            sub: None,
1746        };
1747        let dbg = format!("{id:?}");
1748        assert!(
1749            dbg.contains("<none>"),
1750            "absent secrets should be marked: {dbg}"
1751        );
1752        assert!(
1753            !dbg.contains("<redacted>"),
1754            "no <redacted> marker when secrets are absent: {dbg}"
1755        );
1756    }
1757
1758    #[test]
1759    fn api_key_entry_debug_redacts_hash() {
1760        let entry = ApiKeyEntry {
1761            name: "viewer-key".into(),
1762            // Realistic Argon2id PHC string (must NOT leak).
1763            hash: "$argon2id$v=19$m=19456,t=2,p=1$c2FsdHNhbHQ$h4sh3dPa55w0rd".into(),
1764            role: "viewer".into(),
1765            expires_at: Some("2030-01-01T00:00:00Z".into()),
1766        };
1767        let dbg = format!("{entry:?}");
1768
1769        // Non-secret fields visible.
1770        assert!(dbg.contains("viewer-key"));
1771        assert!(dbg.contains("viewer"));
1772        assert!(dbg.contains("2030-01-01T00:00:00Z"));
1773
1774        // Hash material must NOT leak.
1775        assert!(
1776            !dbg.contains("$argon2id$"),
1777            "argon2 hash leaked into Debug output: {dbg}"
1778        );
1779        assert!(
1780            !dbg.contains("h4sh3dPa55w0rd"),
1781            "hash digest leaked into Debug output: {dbg}"
1782        );
1783        assert!(
1784            dbg.contains("<redacted>"),
1785            "redaction marker missing: {dbg}"
1786        );
1787    }
1788}