1use std::{
10 collections::HashSet,
11 net::{IpAddr, SocketAddr},
12 num::NonZeroU32,
13 path::PathBuf,
14 sync::{
15 Arc, LazyLock, Mutex,
16 atomic::{AtomicU64, Ordering},
17 },
18 time::Duration,
19};
20
21use arc_swap::ArcSwap;
22use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier, password_hash::SaltString};
23use axum::{
24 body::Body,
25 extract::ConnectInfo,
26 http::{Request, header},
27 middleware::Next,
28 response::{IntoResponse, Response},
29};
30use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
31use secrecy::SecretString;
32use serde::Deserialize;
33use x509_parser::prelude::*;
34
35use crate::{bounded_limiter::BoundedKeyedLimiter, error::McpxError};
36
37#[derive(Clone)]
46#[non_exhaustive]
47pub struct AuthIdentity {
48 pub name: String,
50 pub role: String,
52 pub method: AuthMethod,
54 pub raw_token: Option<SecretString>,
60 pub sub: Option<String>,
63}
64
65impl std::fmt::Debug for AuthIdentity {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 f.debug_struct("AuthIdentity")
70 .field("name", &self.name)
71 .field("role", &self.role)
72 .field("method", &self.method)
73 .field(
74 "raw_token",
75 &if self.raw_token.is_some() {
76 "<redacted>"
77 } else {
78 "<none>"
79 },
80 )
81 .field(
82 "sub",
83 &if self.sub.is_some() {
84 "<redacted>"
85 } else {
86 "<none>"
87 },
88 )
89 .finish()
90 }
91}
92
/// Mechanism that established an [`AuthIdentity`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum AuthMethod {
    /// API key sent as a bearer token and verified against a stored Argon2 hash.
    BearerToken,
    /// Identity taken from a client certificate presented over mTLS.
    MtlsCertificate,
    /// Bearer token validated as a JWT (presumably set by the `oauth` module's
    /// JWT validation — confirm).
    OAuthJwt,
}
104
/// Why an authentication attempt failed.
///
/// The class drives three outputs: the metric/log label, the RFC 6750 bearer
/// challenge (`error` code plus description) and the plain-text response body.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AuthFailureClass {
    MissingCredential,
    InvalidCredential,
    #[cfg_attr(not(feature = "oauth"), allow(dead_code))]
    ExpiredCredential,
    RateLimited,
    PreAuthGate,
}

impl AuthFailureClass {
    /// Every string associated with a class, kept in one table.
    ///
    /// The tuple is laid out as: label, then (bearer error code, bearer
    /// description), then response body.
    const fn strings(self) -> (&'static str, (&'static str, &'static str), &'static str) {
        match self {
            Self::MissingCredential => (
                "missing_credential",
                (
                    "invalid_request",
                    "missing bearer token or mTLS client certificate",
                ),
                "unauthorized: missing credential",
            ),
            Self::InvalidCredential => (
                "invalid_credential",
                ("invalid_token", "token is invalid"),
                "unauthorized: invalid credential",
            ),
            Self::ExpiredCredential => (
                "expired_credential",
                ("invalid_token", "token is expired"),
                "unauthorized: expired credential",
            ),
            Self::RateLimited => (
                "rate_limited",
                ("invalid_request", "too many failed authentication attempts"),
                "rate limited",
            ),
            Self::PreAuthGate => (
                "pre_auth_gate",
                (
                    "invalid_request",
                    "too many unauthenticated requests from this source",
                ),
                "rate limited (pre-auth)",
            ),
        }
    }

    /// Stable snake_case label used in logs.
    fn as_str(self) -> &'static str {
        let (label, _, _) = self.strings();
        label
    }

    /// `(error, error_description)` pair for the `WWW-Authenticate` challenge.
    fn bearer_error(self) -> (&'static str, &'static str) {
        let (_, challenge, _) = self.strings();
        challenge
    }

    /// Plain-text body returned with the 401 response.
    fn response_body(self) -> &'static str {
        let (_, _, body) = self.strings();
        body
    }
}
155
/// Serializable point-in-time copy of the authentication outcome counters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
#[non_exhaustive]
pub struct AuthCountersSnapshot {
    /// Successful mTLS client-certificate authentications.
    pub success_mtls: u64,
    /// Successful API-key (bearer token) authentications.
    pub success_bearer: u64,
    /// Successful OAuth JWT authentications.
    pub success_oauth_jwt: u64,
    /// Requests rejected because no `Authorization` header was present.
    pub failure_missing_credential: u64,
    /// Requests rejected because the presented credential did not verify.
    pub failure_invalid_credential: u64,
    /// Requests rejected because the presented credential had expired.
    pub failure_expired_credential: u64,
    /// Failed requests rejected by the failed-attempt limiter. These are not also
    /// counted under their underlying failure class.
    pub failure_rate_limited: u64,
    /// Requests rejected by the pre-auth gate before any credential verification.
    pub failure_pre_auth_gate: u64,
}
178
179#[derive(Debug, Default)]
181pub(crate) struct AuthCounters {
182 success_mtls: AtomicU64,
183 success_bearer: AtomicU64,
184 success_oauth_jwt: AtomicU64,
185 failure_missing_credential: AtomicU64,
186 failure_invalid_credential: AtomicU64,
187 failure_expired_credential: AtomicU64,
188 failure_rate_limited: AtomicU64,
189 failure_pre_auth_gate: AtomicU64,
190}
191
192impl AuthCounters {
193 fn record_success(&self, method: AuthMethod) {
194 match method {
195 AuthMethod::MtlsCertificate => {
196 self.success_mtls.fetch_add(1, Ordering::Relaxed);
197 }
198 AuthMethod::BearerToken => {
199 self.success_bearer.fetch_add(1, Ordering::Relaxed);
200 }
201 AuthMethod::OAuthJwt => {
202 self.success_oauth_jwt.fetch_add(1, Ordering::Relaxed);
203 }
204 }
205 }
206
207 fn record_failure(&self, class: AuthFailureClass) {
208 match class {
209 AuthFailureClass::MissingCredential => {
210 self.failure_missing_credential
211 .fetch_add(1, Ordering::Relaxed);
212 }
213 AuthFailureClass::InvalidCredential => {
214 self.failure_invalid_credential
215 .fetch_add(1, Ordering::Relaxed);
216 }
217 AuthFailureClass::ExpiredCredential => {
218 self.failure_expired_credential
219 .fetch_add(1, Ordering::Relaxed);
220 }
221 AuthFailureClass::RateLimited => {
222 self.failure_rate_limited.fetch_add(1, Ordering::Relaxed);
223 }
224 AuthFailureClass::PreAuthGate => {
225 self.failure_pre_auth_gate.fetch_add(1, Ordering::Relaxed);
226 }
227 }
228 }
229
230 fn snapshot(&self) -> AuthCountersSnapshot {
231 AuthCountersSnapshot {
232 success_mtls: self.success_mtls.load(Ordering::Relaxed),
233 success_bearer: self.success_bearer.load(Ordering::Relaxed),
234 success_oauth_jwt: self.success_oauth_jwt.load(Ordering::Relaxed),
235 failure_missing_credential: self.failure_missing_credential.load(Ordering::Relaxed),
236 failure_invalid_credential: self.failure_invalid_credential.load(Ordering::Relaxed),
237 failure_expired_credential: self.failure_expired_credential.load(Ordering::Relaxed),
238 failure_rate_limited: self.failure_rate_limited.load(Ordering::Relaxed),
239 failure_pre_auth_gate: self.failure_pre_auth_gate.load(Ordering::Relaxed),
240 }
241 }
242}
243
244#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
256#[non_exhaustive]
257pub struct RfcTimestamp(chrono::DateTime<chrono::FixedOffset>);
258
259impl RfcTimestamp {
260 pub fn parse(s: &str) -> Result<Self, chrono::ParseError> {
268 chrono::DateTime::parse_from_rfc3339(s).map(Self)
269 }
270
271 #[must_use]
273 pub fn as_datetime(&self) -> &chrono::DateTime<chrono::FixedOffset> {
274 &self.0
275 }
276
277 #[must_use]
279 pub fn into_inner(self) -> chrono::DateTime<chrono::FixedOffset> {
280 self.0
281 }
282}
283
284impl std::fmt::Display for RfcTimestamp {
285 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286 write!(f, "{}", self.0.to_rfc3339())
288 }
289}
290
291impl std::fmt::Debug for RfcTimestamp {
292 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
293 write!(f, "{}", self.0.to_rfc3339())
298 }
299}
300
301impl<'de> Deserialize<'de> for RfcTimestamp {
302 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
303 where
304 D: serde::Deserializer<'de>,
305 {
306 let s = String::deserialize(deserializer)?;
310 Self::parse(&s).map_err(serde::de::Error::custom)
311 }
312}
313
314impl serde::Serialize for RfcTimestamp {
315 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
316 where
317 S: serde::Serializer,
318 {
319 serializer.serialize_str(&self.0.to_rfc3339())
320 }
321}
322
323impl From<chrono::DateTime<chrono::FixedOffset>> for RfcTimestamp {
324 fn from(value: chrono::DateTime<chrono::FixedOffset>) -> Self {
325 Self(value)
326 }
327}
328
329#[derive(Clone, Deserialize)]
336#[non_exhaustive]
337pub struct ApiKeyEntry {
338 pub name: String,
340 pub hash: String,
342 pub role: String,
344 pub expires_at: Option<RfcTimestamp>,
349}
350
351impl std::fmt::Debug for ApiKeyEntry {
352 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355 f.debug_struct("ApiKeyEntry")
356 .field("name", &self.name)
357 .field("hash", &"<redacted>")
358 .field("role", &self.role)
359 .field("expires_at", &self.expires_at)
360 .finish()
361 }
362}
363
364impl ApiKeyEntry {
365 #[must_use]
367 pub fn new(name: impl Into<String>, hash: impl Into<String>, role: impl Into<String>) -> Self {
368 Self {
369 name: name.into(),
370 hash: hash.into(),
371 role: role.into(),
372 expires_at: None,
373 }
374 }
375
376 #[must_use]
381 pub fn with_expiry(mut self, expires_at: RfcTimestamp) -> Self {
382 self.expires_at = Some(expires_at);
383 self
384 }
385
386 pub fn try_with_expiry(
394 mut self,
395 expires_at: impl AsRef<str>,
396 ) -> Result<Self, chrono::ParseError> {
397 self.expires_at = Some(RfcTimestamp::parse(expires_at.as_ref())?);
398 Ok(self)
399 }
400}
401
/// mTLS client-certificate authentication settings, including CRL handling.
///
/// Duration fields are parsed with `humantime_serde` (e.g. `"30s"`, `"24h"`).
#[derive(Debug, Clone, Deserialize)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "mTLS CRL behavior is intentionally configured as independent booleans"
)]
#[non_exhaustive]
pub struct MtlsConfig {
    /// Path to the CA certificate used to verify client certificates.
    pub ca_cert_path: PathBuf,
    /// Default: `false`. (Presumably makes a client certificate mandatory — confirm.)
    #[serde(default)]
    pub required: bool,
    /// Role given to mTLS identities. Default: `"viewer"`.
    #[serde(default = "default_mtls_role")]
    pub default_role: String,
    /// Enables CRL checking. Default: `true`.
    #[serde(default = "default_true")]
    pub crl_enabled: bool,
    /// Optional CRL refresh interval. Default: unset (`None`).
    #[serde(default, with = "humantime_serde::option")]
    pub crl_refresh_interval: Option<Duration>,
    /// Timeout for a single CRL fetch. Default: 30 seconds.
    #[serde(default = "default_crl_fetch_timeout", with = "humantime_serde")]
    pub crl_fetch_timeout: Duration,
    /// Grace period for stale CRLs. Default: 24 hours.
    #[serde(default = "default_crl_stale_grace", with = "humantime_serde")]
    pub crl_stale_grace: Duration,
    /// Default: `false`. (Presumably denies connections when no CRL is available — confirm.)
    #[serde(default)]
    pub crl_deny_on_unavailable: bool,
    /// Default: `false`. (Presumably limits revocation checks to the leaf certificate — confirm.)
    #[serde(default)]
    pub crl_end_entity_only: bool,
    /// Allows CRL distribution points over plain HTTP. Default: `true`.
    #[serde(default = "default_true")]
    pub crl_allow_http: bool,
    /// Default: `true`. (Presumably enforces CRL `nextUpdate` expiry — confirm.)
    #[serde(default = "default_true")]
    pub crl_enforce_expiration: bool,
    /// Maximum concurrent CRL fetches. Default: 4.
    #[serde(default = "default_crl_max_concurrent_fetches")]
    pub crl_max_concurrent_fetches: usize,
    /// Maximum CRL response size in bytes. Default: 5 MiB.
    #[serde(default = "default_crl_max_response_bytes")]
    pub crl_max_response_bytes: u64,
    /// CRL URL discovery rate limit per minute. Default: 60.
    #[serde(default = "default_crl_discovery_rate_per_min")]
    pub crl_discovery_rate_per_min: u32,
    /// Upper bound on per-host fetch semaphores. Default: 1024.
    #[serde(default = "default_crl_max_host_semaphores")]
    pub crl_max_host_semaphores: usize,
    /// Upper bound on remembered CRL URLs. Default: 4096.
    #[serde(default = "default_crl_max_seen_urls")]
    pub crl_max_seen_urls: usize,
    /// Upper bound on cached CRLs. Default: 1024.
    #[serde(default = "default_crl_max_cache_entries")]
    pub crl_max_cache_entries: usize,
}
499
/// Serde default for [`MtlsConfig::default_role`].
fn default_mtls_role() -> String {
    String::from("viewer")
}

/// Serde default for boolean flags that are on unless explicitly disabled.
const fn default_true() -> bool {
    true
}

/// Serde default for [`MtlsConfig::crl_fetch_timeout`]: 30 seconds.
const fn default_crl_fetch_timeout() -> Duration {
    Duration::from_secs(30)
}

/// Serde default for [`MtlsConfig::crl_stale_grace`]: 24 hours.
const fn default_crl_stale_grace() -> Duration {
    const SECONDS_PER_DAY: u64 = 24 * 60 * 60;
    Duration::from_secs(SECONDS_PER_DAY)
}

/// Serde default for [`MtlsConfig::crl_max_concurrent_fetches`].
const fn default_crl_max_concurrent_fetches() -> usize {
    4
}

/// Serde default for [`MtlsConfig::crl_max_response_bytes`]: 5 MiB.
const fn default_crl_max_response_bytes() -> u64 {
    const MIB: u64 = 1024 * 1024;
    5 * MIB
}

/// Serde default for [`MtlsConfig::crl_discovery_rate_per_min`].
const fn default_crl_discovery_rate_per_min() -> u32 {
    60
}

/// Serde default for [`MtlsConfig::crl_max_host_semaphores`].
const fn default_crl_max_host_semaphores() -> usize {
    1024
}

/// Serde default for [`MtlsConfig::crl_max_seen_urls`].
const fn default_crl_max_seen_urls() -> usize {
    4096
}

/// Serde default for [`MtlsConfig::crl_max_cache_entries`].
const fn default_crl_max_cache_entries() -> usize {
    1024
}
539
540#[derive(Debug, Clone, Deserialize)]
555#[non_exhaustive]
556pub struct RateLimitConfig {
557 #[serde(default = "default_max_attempts")]
560 pub max_attempts_per_minute: u32,
561 #[serde(default)]
569 pub pre_auth_max_per_minute: Option<u32>,
570 #[serde(default = "default_max_tracked_keys")]
575 pub max_tracked_keys: usize,
576 #[serde(default = "default_idle_eviction", with = "humantime_serde")]
579 pub idle_eviction: Duration,
580}
581
582impl Default for RateLimitConfig {
583 fn default() -> Self {
584 Self {
585 max_attempts_per_minute: default_max_attempts(),
586 pre_auth_max_per_minute: None,
587 max_tracked_keys: default_max_tracked_keys(),
588 idle_eviction: default_idle_eviction(),
589 }
590 }
591}
592
593impl RateLimitConfig {
594 #[must_use]
598 pub fn new(max_attempts_per_minute: u32) -> Self {
599 Self {
600 max_attempts_per_minute,
601 ..Self::default()
602 }
603 }
604
605 #[must_use]
608 pub fn with_pre_auth_max_per_minute(mut self, quota: u32) -> Self {
609 self.pre_auth_max_per_minute = Some(quota);
610 self
611 }
612
613 #[must_use]
615 pub fn with_max_tracked_keys(mut self, max: usize) -> Self {
616 self.max_tracked_keys = max;
617 self
618 }
619
620 #[must_use]
622 pub fn with_idle_eviction(mut self, idle: Duration) -> Self {
623 self.idle_eviction = idle;
624 self
625 }
626}
627
/// Default failed-attempt quota per IP per minute.
///
/// It is a `const fn`, like the other serde default helpers in this module.
const fn default_max_attempts() -> u32 {
    30
}

/// Default bound on keys tracked by a rate limiter.
const fn default_max_tracked_keys() -> usize {
    10_000
}

/// Default idle-eviction interval for rate-limiter keys: 15 minutes.
const fn default_idle_eviction() -> Duration {
    Duration::from_secs(15 * 60)
}
639
640#[derive(Debug, Clone, Default, Deserialize)]
642#[non_exhaustive]
643pub struct AuthConfig {
644 #[serde(default)]
646 pub enabled: bool,
647 #[serde(default)]
649 pub api_keys: Vec<ApiKeyEntry>,
650 pub mtls: Option<MtlsConfig>,
652 pub rate_limit: Option<RateLimitConfig>,
654 #[cfg(feature = "oauth")]
656 pub oauth: Option<crate::oauth::OAuthConfig>,
657}
658
659impl AuthConfig {
660 #[must_use]
662 pub fn with_keys(keys: Vec<ApiKeyEntry>) -> Self {
663 Self {
664 enabled: true,
665 api_keys: keys,
666 mtls: None,
667 rate_limit: None,
668 #[cfg(feature = "oauth")]
669 oauth: None,
670 }
671 }
672
673 #[must_use]
675 pub fn with_rate_limit(mut self, rate_limit: RateLimitConfig) -> Self {
676 self.rate_limit = Some(rate_limit);
677 self
678 }
679}
680
/// Secret-free view of an [`ApiKeyEntry`] (no hash), suitable for serialization.
#[derive(Debug, Clone, serde::Serialize)]
#[non_exhaustive]
pub struct ApiKeySummary {
    /// Key name.
    pub name: String,
    /// Role granted by the key.
    pub role: String,
    /// Expiry, serialized as an RFC 3339 string when present.
    pub expires_at: Option<RfcTimestamp>,
}

/// Serializable overview of which authentication methods are configured.
#[derive(Debug, Clone, serde::Serialize)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "this is a flat summary of independent auth-method booleans"
)]
#[non_exhaustive]
pub struct AuthConfigSummary {
    /// Mirrors `AuthConfig::enabled`.
    pub enabled: bool,
    /// `true` when at least one API key is configured.
    pub bearer: bool,
    /// `true` when an mTLS section is configured.
    pub mtls: bool,
    /// `true` when OAuth is configured; always `false` without the `oauth` feature.
    pub oauth: bool,
    /// Per-key summaries, in configuration order.
    pub api_keys: Vec<ApiKeySummary>,
}
715
716impl AuthConfig {
717 #[must_use]
719 pub fn summary(&self) -> AuthConfigSummary {
720 AuthConfigSummary {
721 enabled: self.enabled,
722 bearer: !self.api_keys.is_empty(),
723 mtls: self.mtls.is_some(),
724 #[cfg(feature = "oauth")]
725 oauth: self.oauth.is_some(),
726 #[cfg(not(feature = "oauth"))]
727 oauth: false,
728 api_keys: self
729 .api_keys
730 .iter()
731 .map(|k| ApiKeySummary {
732 name: k.name.clone(),
733 role: k.role.clone(),
734 expires_at: k.expires_at,
735 })
736 .collect(),
737 }
738 }
739}
740
741pub(crate) type KeyedLimiter = BoundedKeyedLimiter<IpAddr>;
744
745#[derive(Clone, Debug)]
755#[non_exhaustive]
756pub(crate) struct TlsConnInfo {
757 pub addr: SocketAddr,
759 pub identity: Option<AuthIdentity>,
762}
763
764impl TlsConnInfo {
765 #[must_use]
767 pub(crate) const fn new(addr: SocketAddr, identity: Option<AuthIdentity>) -> Self {
768 Self { addr, identity }
769 }
770}
771
/// Shared runtime state for `auth_middleware`.
///
/// It holds the active API keys, the optional rate limiters, the JWKS cache (with
/// the `oauth` feature), first-seen identity tracking for logging, and outcome
/// counters.
#[allow(
    missing_debug_implementations,
    reason = "contains governor RateLimiter and JwksCache without Debug impls"
)]
#[non_exhaustive]
pub(crate) struct AuthState {
    /// Active API keys; `reload_keys` swaps them atomically.
    pub api_keys: ArcSwap<Vec<ApiKeyEntry>>,
    /// Per-IP limiter consulted only after an authentication attempt has failed.
    pub rate_limiter: Option<Arc<KeyedLimiter>>,
    /// Per-IP limiter consulted before credential verification. Requests carrying an
    /// mTLS identity bypass it.
    pub pre_auth_limiter: Option<Arc<KeyedLimiter>>,
    /// JWKS cache for JWT validation. When set, 401 challenges advertise resource metadata.
    #[cfg(feature = "oauth")]
    pub jwks_cache: Option<Arc<crate::oauth::JwksCache>>,
    /// Identity names already logged once at `info`; repeat logins are logged at `debug`.
    /// NOTE(review): never pruned in this file, so it grows with distinct identities.
    pub seen_identities: Mutex<HashSet<String>>,
    /// Success/failure counters, exposed via `counters_snapshot`.
    pub counters: AuthCounters,
}
798
799impl AuthState {
800 pub(crate) fn reload_keys(&self, keys: Vec<ApiKeyEntry>) {
806 let count = keys.len();
807 self.api_keys.store(Arc::new(keys));
808 tracing::info!(keys = count, "API keys reloaded");
809 }
810
811 #[must_use]
813 pub(crate) fn counters_snapshot(&self) -> AuthCountersSnapshot {
814 self.counters.snapshot()
815 }
816
817 #[must_use]
819 pub(crate) fn api_key_summaries(&self) -> Vec<ApiKeySummary> {
820 self.api_keys
821 .load()
822 .iter()
823 .map(|k| ApiKeySummary {
824 name: k.name.clone(),
825 role: k.role.clone(),
826 expires_at: k.expires_at,
827 })
828 .collect()
829 }
830
831 fn log_auth(&self, id: &AuthIdentity, method: &str) {
833 self.counters.record_success(id.method);
834 let first = self
835 .seen_identities
836 .lock()
837 .unwrap_or_else(std::sync::PoisonError::into_inner)
838 .insert(id.name.clone());
839 if first {
840 tracing::info!(name = %id.name, role = %id.role, "{method} authenticated");
841 } else {
842 tracing::debug!(name = %id.name, role = %id.role, "{method} authenticated");
843 }
844 }
845}
846
847const DEFAULT_AUTH_RATE: NonZeroU32 = NonZeroU32::new(30).unwrap();
850
851#[must_use]
853pub(crate) fn build_rate_limiter(config: &RateLimitConfig) -> Arc<KeyedLimiter> {
854 let quota = governor::Quota::per_minute(
855 NonZeroU32::new(config.max_attempts_per_minute).unwrap_or(DEFAULT_AUTH_RATE),
856 );
857 Arc::new(BoundedKeyedLimiter::new(
858 quota,
859 config.max_tracked_keys,
860 config.idle_eviction,
861 ))
862}
863
864#[must_use]
871pub(crate) fn build_pre_auth_limiter(config: &RateLimitConfig) -> Arc<KeyedLimiter> {
872 let resolved = config.pre_auth_max_per_minute.unwrap_or_else(|| {
873 config
874 .max_attempts_per_minute
875 .saturating_mul(PRE_AUTH_DEFAULT_MULTIPLIER)
876 });
877 let quota =
878 governor::Quota::per_minute(NonZeroU32::new(resolved).unwrap_or(DEFAULT_PRE_AUTH_RATE));
879 Arc::new(BoundedKeyedLimiter::new(
880 quota,
881 config.max_tracked_keys,
882 config.idle_eviction,
883 ))
884}
885
886const PRE_AUTH_DEFAULT_MULTIPLIER: u32 = 10;
889
890const DEFAULT_PRE_AUTH_RATE: NonZeroU32 = NonZeroU32::new(300).unwrap();
894
895#[must_use]
900pub fn extract_mtls_identity(cert_der: &[u8], default_role: &str) -> Option<AuthIdentity> {
901 let (_, cert) = X509Certificate::from_der(cert_der).ok()?;
902
903 let cn = cert
905 .subject()
906 .iter_common_name()
907 .next()
908 .and_then(|attr| attr.as_str().ok())
909 .map(String::from);
910
911 let name = cn.or_else(|| {
913 cert.subject_alternative_name()
914 .ok()
915 .flatten()
916 .and_then(|san| {
917 #[allow(clippy::wildcard_enum_match_arm)]
918 san.value.general_names.iter().find_map(|gn| match gn {
919 GeneralName::DNSName(dns) => Some((*dns).to_owned()),
920 _ => None,
921 })
922 })
923 })?;
924
925 if !name
927 .chars()
928 .all(|c| c.is_alphanumeric() || matches!(c, '-' | '.' | '_' | '@'))
929 {
930 tracing::warn!(cn = %name, "mTLS identity rejected: invalid characters in CN/SAN");
931 return None;
932 }
933
934 Some(AuthIdentity {
935 name,
936 role: default_role.to_owned(),
937 method: AuthMethod::MtlsCertificate,
938 raw_token: None,
939 sub: None,
940 })
941}
942
/// Extracts the token from an `Authorization` value of the form `Bearer <token>`.
///
/// The scheme is matched ASCII case-insensitively and must be followed by at least
/// one space. Additional leading spaces before the token are skipped.
///
/// Returns `None` for any other scheme, for a missing separator, or for an empty
/// token.
fn extract_bearer(value: &str) -> Option<&str> {
    let (scheme, remainder) = value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    Some(remainder.trim_start_matches(' ')).filter(|token| !token.is_empty())
}
966
/// Verifies a bearer token against the configured API keys.
///
/// Every entry costs exactly one Argon2 verification, wherever (or whether) a match
/// occurs. Three kinds of entry are verified against `DUMMY_PHC_HASH` instead of
/// their real hash:
/// - entries that have expired,
/// - entries whose stored hash is not a parseable PHC string,
/// - entries that come after the first match.
///
/// This keeps the amount of hashing work independent of which slot matched. The
/// first entry that is not expired, has a parseable hash and verifies wins.
///
/// Returns an [`AuthIdentity`] with [`AuthMethod::BearerToken`] on success. Returns
/// `None` when no entry matches, which includes an empty `keys`.
///
/// NOTE(review): the `match` on `any_match` and the `if` around `matched_index` are
/// still data-dependent branches. Per-slot work is uniform, but this is not a strict
/// constant-time implementation. Dummy verifications also use the dummy hash's
/// Argon2 parameters, which may differ from those stored in real entries.
#[must_use]
pub fn verify_bearer_token(token: &str, keys: &[ApiKeyEntry]) -> Option<AuthIdentity> {
    use subtle::ConstantTimeEq as _;

    let now = chrono::Utc::now();
    #[allow(
        clippy::expect_used,
        reason = "DUMMY_PHC_HASH is a static LazyLock built from a fixed Argon2id PHC string by construction; PasswordHash::new on it is infallible. See DUMMY_PHC_HASH definition."
    )]
    let dummy_hash = PasswordHash::new(&DUMMY_PHC_HASH)
        .expect("DUMMY_PHC_HASH is a valid Argon2id PHC string by construction");

    // Sentinel meaning "no match yet"; it is only read after `any_match` is known to be 1.
    let mut matched_index: usize = usize::MAX;
    // 0/1 flag held as a u8 so that it can be combined with the bitwise ops below.
    let mut any_match: u8 = 0;

    for (idx, key) in keys.iter().enumerate() {
        // Strict `<`: a key whose expiry equals `now` is still accepted.
        let expired = key.expires_at.is_some_and(|exp| exp.as_datetime() < &now);

        let real_hash = PasswordHash::new(&key.hash);
        // Only a parseable, unexpired slot seen before any match is checked against its
        // real hash; every other slot burns the same work against the dummy hash.
        let verify_against = match (&real_hash, expired, any_match) {
            (Ok(h), false, 0) => h,
            _ => &dummy_hash,
        };

        let slot_ok = u8::from(
            Argon2::default()
                .verify_password(token.as_bytes(), verify_against)
                .is_ok(),
        );

        // Masks out successes that came from the dummy hash (expired or unparseable slots).
        let real_match = slot_ok & u8::from(!expired) & u8::from(real_hash.is_ok());
        // 1 only for the first genuine match; later matches are ignored.
        let first_real_match = real_match & (1 - any_match);
        if first_real_match.ct_eq(&1).into() {
            matched_index = idx;
        }
        any_match |= real_match;
    }

    if any_match == 0 {
        return None;
    }
    let key = keys.get(matched_index)?;
    Some(AuthIdentity {
        name: key.name.clone(),
        role: key.role.clone(),
        method: AuthMethod::BearerToken,
        raw_token: None,
        sub: None,
    })
}
1045
1046static DUMMY_PHC_HASH: LazyLock<String> = LazyLock::new(|| {
1059 #[allow(
1061 clippy::expect_used,
1062 reason = "fixed 22-char base64 ('AAAA...') decodes to a valid 16-byte salt; SaltString::from_b64 is infallible on this literal"
1063 )]
1064 let salt = SaltString::from_b64("AAAAAAAAAAAAAAAAAAAAAA")
1065 .expect("fixed 16-byte base64 salt is well-formed");
1066 #[allow(
1067 clippy::expect_used,
1068 reason = "Argon2::default() with a fixed plaintext and a well-formed salt is infallible; only fails on bad params/salt"
1069 )]
1070 Argon2::default()
1071 .hash_password(b"rmcp-server-kit-dummy", &salt)
1072 .expect("Argon2 default params hash a fixed plaintext")
1073 .to_string()
1074});
1075
1076pub fn generate_api_key() -> Result<(String, String), McpxError> {
1086 let mut token_bytes = [0u8; 32];
1087 rand::fill(&mut token_bytes);
1088 let token = URL_SAFE_NO_PAD.encode(token_bytes);
1089
1090 let mut salt_bytes = [0u8; 16];
1092 rand::fill(&mut salt_bytes);
1093 let salt = SaltString::encode_b64(&salt_bytes)
1094 .map_err(|e| McpxError::Auth(format!("salt encoding failed: {e}")))?;
1095 let hash = Argon2::default()
1096 .hash_password(token.as_bytes(), &salt)
1097 .map_err(|e| McpxError::Auth(format!("argon2id hashing failed: {e}")))?
1098 .to_string();
1099
1100 Ok((token, hash))
1101}
1102
1103fn build_www_authenticate_value(
1104 advertise_resource_metadata: bool,
1105 failure: AuthFailureClass,
1106) -> String {
1107 let (error, error_description) = failure.bearer_error();
1108 if advertise_resource_metadata {
1109 return format!(
1110 "Bearer resource_metadata=\"/.well-known/oauth-protected-resource\", error=\"{error}\", error_description=\"{error_description}\""
1111 );
1112 }
1113 format!("Bearer error=\"{error}\", error_description=\"{error_description}\"")
1114}
1115
1116fn auth_method_label(method: AuthMethod) -> &'static str {
1117 match method {
1118 AuthMethod::MtlsCertificate => "mTLS",
1119 AuthMethod::BearerToken => "bearer token",
1120 AuthMethod::OAuthJwt => "OAuth JWT",
1121 }
1122}
1123
1124#[cfg_attr(not(feature = "oauth"), allow(unused_variables))]
1125fn unauthorized_response(state: &AuthState, failure_class: AuthFailureClass) -> Response {
1126 #[cfg(feature = "oauth")]
1127 let advertise_resource_metadata = state.jwks_cache.is_some();
1128 #[cfg(not(feature = "oauth"))]
1129 let advertise_resource_metadata = false;
1130
1131 let challenge = build_www_authenticate_value(advertise_resource_metadata, failure_class);
1132 (
1133 axum::http::StatusCode::UNAUTHORIZED,
1134 [(header::WWW_AUTHENTICATE, challenge)],
1135 failure_class.response_body(),
1136 )
1137 .into_response()
1138}
1139
1140async fn authenticate_bearer_identity(
1141 state: &AuthState,
1142 token: &str,
1143) -> Result<AuthIdentity, AuthFailureClass> {
1144 let mut failure_class = AuthFailureClass::MissingCredential;
1145
1146 #[cfg(feature = "oauth")]
1147 if let Some(ref cache) = state.jwks_cache
1148 && crate::oauth::looks_like_jwt(token)
1149 {
1150 match cache.validate_token_with_reason(token).await {
1151 Ok(mut id) => {
1152 id.raw_token = Some(SecretString::from(token.to_owned()));
1153 return Ok(id);
1154 }
1155 Err(crate::oauth::JwtValidationFailure::Expired) => {
1156 failure_class = AuthFailureClass::ExpiredCredential;
1157 }
1158 Err(crate::oauth::JwtValidationFailure::Invalid) => {
1159 failure_class = AuthFailureClass::InvalidCredential;
1160 }
1161 }
1162 }
1163
1164 let token = token.to_owned();
1165 let keys = state.api_keys.load_full(); let identity = tokio::task::spawn_blocking(move || verify_bearer_token(&token, &keys))
1169 .await
1170 .ok()
1171 .flatten();
1172
1173 if let Some(id) = identity {
1174 return Ok(id);
1175 }
1176
1177 if failure_class == AuthFailureClass::MissingCredential {
1178 failure_class = AuthFailureClass::InvalidCredential;
1179 }
1180
1181 Err(failure_class)
1182}
1183
1184fn pre_auth_gate(state: &AuthState, peer_addr: Option<SocketAddr>) -> Option<Response> {
1195 let limiter = state.pre_auth_limiter.as_ref()?;
1196 let addr = peer_addr?;
1197 if limiter.check_key(&addr.ip()).is_ok() {
1198 return None;
1199 }
1200 state.counters.record_failure(AuthFailureClass::PreAuthGate);
1201 tracing::warn!(
1202 ip = %addr.ip(),
1203 "auth rate limited by pre-auth gate (request rejected before credential verification)"
1204 );
1205 Some(
1206 McpxError::RateLimited("too many unauthenticated requests from this source".into())
1207 .into_response(),
1208 )
1209}
1210
/// Axum middleware that authenticates each request before handing it to `next`.
///
/// Checks run in this order:
/// 1. If the TLS connect-info carries a client-certificate identity, the request is
///    accepted immediately; neither limiter is consulted.
/// 2. The pre-auth gate (if configured) may reject the peer IP before any credential work.
/// 3. An `Authorization: Bearer` token is verified (JWT first when available, then API keys).
/// 4. On failure, the failed-attempt limiter (if configured) may turn the response
///    into a rate-limit error. Otherwise a 401 with a `WWW-Authenticate` challenge is
///    returned.
///
/// On success the [`AuthIdentity`] is inserted into the request extensions.
///
/// NOTE(review): the failed-attempt limiter is consulted only after verification.
/// A client that is already over its failure quota therefore still has its
/// credential verified, and is admitted if that credential is correct. Only the
/// pre-auth gate bounds attempts before verification. Confirm this is intended.
pub(crate) async fn auth_middleware(
    state: Arc<AuthState>,
    req: Request<Body>,
    next: Next,
) -> Response {
    let tls_info = req.extensions().get::<ConnectInfo<TlsConnInfo>>().cloned();
    // Prefer a plain `SocketAddr` connect-info; fall back to the TLS connection's address.
    let peer_addr = req
        .extensions()
        .get::<ConnectInfo<SocketAddr>>()
        .map(|ci| ci.0)
        .or_else(|| tls_info.as_ref().map(|ci| ci.0.addr));

    if let Some(id) = tls_info.and_then(|ci| ci.0.identity) {
        state.log_auth(&id, "mTLS");
        let mut req = req;
        req.extensions_mut().insert(id);
        return next.run(req).await;
    }

    if let Some(blocked) = pre_auth_gate(&state, peer_addr) {
        return blocked;
    }

    // A header that is present but not a usable `Bearer <token>` (wrong scheme, empty
    // token, non-visible-ASCII bytes) counts as an invalid credential, not a missing one.
    let failure_class = if let Some(value) = req.headers().get(header::AUTHORIZATION) {
        match value.to_str().ok().and_then(extract_bearer) {
            Some(token) => match authenticate_bearer_identity(&state, token).await {
                Ok(id) => {
                    state.log_auth(&id, auth_method_label(id.method));
                    let mut req = req;
                    req.extensions_mut().insert(id);
                    return next.run(req).await;
                }
                Err(class) => class,
            },
            None => AuthFailureClass::InvalidCredential,
        }
    } else {
        AuthFailureClass::MissingCredential
    };

    tracing::warn!(failure_class = %failure_class.as_str(), "auth failed");

    // Each failure consumes one token from the per-IP failure quota. Once the quota is
    // exhausted, the response becomes a rate-limit error, and only `RateLimited` is
    // counted (not the underlying class).
    if let (Some(limiter), Some(addr)) = (&state.rate_limiter, peer_addr)
        && limiter.check_key(&addr.ip()).is_err()
    {
        state.counters.record_failure(AuthFailureClass::RateLimited);
        tracing::warn!(ip = %addr.ip(), "auth rate limited after repeated failures");
        return McpxError::RateLimited("too many failed authentication attempts".into())
            .into_response();
    }

    state.counters.record_failure(failure_class);
    unauthorized_response(&state, failure_class)
}
1288
1289#[cfg(test)]
1290mod tests {
1291 use super::*;
1292
1293 #[test]
1294 fn generate_and_verify_api_key() {
1295 let (token, hash) = generate_api_key().unwrap();
1296
1297 assert_eq!(token.len(), 43);
1299
1300 assert!(hash.starts_with("$argon2id$"));
1302
1303 let keys = vec![ApiKeyEntry {
1305 name: "test".into(),
1306 hash,
1307 role: "viewer".into(),
1308 expires_at: None,
1309 }];
1310 let id = verify_bearer_token(&token, &keys);
1311 assert!(id.is_some());
1312 let id = id.unwrap();
1313 assert_eq!(id.name, "test");
1314 assert_eq!(id.role, "viewer");
1315 assert_eq!(id.method, AuthMethod::BearerToken);
1316 }
1317
1318 #[test]
1319 fn wrong_token_rejected() {
1320 let (_token, hash) = generate_api_key().unwrap();
1321 let keys = vec![ApiKeyEntry {
1322 name: "test".into(),
1323 hash,
1324 role: "viewer".into(),
1325 expires_at: None,
1326 }];
1327 assert!(verify_bearer_token("wrong-token", &keys).is_none());
1328 }
1329
1330 #[test]
1331 fn expired_key_rejected() {
1332 let (token, hash) = generate_api_key().unwrap();
1333 let keys = vec![ApiKeyEntry {
1334 name: "test".into(),
1335 hash,
1336 role: "viewer".into(),
1337 expires_at: Some(RfcTimestamp::parse("2020-01-01T00:00:00Z").unwrap()),
1338 }];
1339 assert!(verify_bearer_token(&token, &keys).is_none());
1340 }
1341
1342 #[test]
1343 fn match_in_last_slot_still_authenticates() {
1344 let (token, hash) = generate_api_key().unwrap();
1345 let (_other_token, other_hash) = generate_api_key().unwrap();
1346 let keys = vec![
1347 ApiKeyEntry {
1348 name: "first".into(),
1349 hash: other_hash.clone(),
1350 role: "viewer".into(),
1351 expires_at: None,
1352 },
1353 ApiKeyEntry {
1354 name: "second".into(),
1355 hash: other_hash,
1356 role: "viewer".into(),
1357 expires_at: None,
1358 },
1359 ApiKeyEntry {
1360 name: "match".into(),
1361 hash,
1362 role: "ops".into(),
1363 expires_at: None,
1364 },
1365 ];
1366 let id = verify_bearer_token(&token, &keys).expect("last-slot match must authenticate");
1367 assert_eq!(id.name, "match");
1368 assert_eq!(id.role, "ops");
1369 }
1370
1371 #[test]
1372 fn expired_slot_before_valid_match_does_not_short_circuit() {
1373 let (token, hash) = generate_api_key().unwrap();
1374 let (_, other_hash) = generate_api_key().unwrap();
1375 let keys = vec![
1376 ApiKeyEntry {
1377 name: "expired".into(),
1378 hash: other_hash,
1379 role: "viewer".into(),
1380 expires_at: Some(RfcTimestamp::parse("2020-01-01T00:00:00Z").unwrap()),
1381 },
1382 ApiKeyEntry {
1383 name: "valid".into(),
1384 hash,
1385 role: "ops".into(),
1386 expires_at: None,
1387 },
1388 ];
1389 let id = verify_bearer_token(&token, &keys)
1390 .expect("valid slot following an expired slot must authenticate");
1391 assert_eq!(id.name, "valid");
1392 }
1393
1394 #[test]
1395 fn malformed_hash_slot_does_not_short_circuit() {
1396 let (token, hash) = generate_api_key().unwrap();
1397 let keys = vec![
1398 ApiKeyEntry {
1399 name: "broken".into(),
1400 hash: "this-is-not-a-phc-string".into(),
1401 role: "viewer".into(),
1402 expires_at: None,
1403 },
1404 ApiKeyEntry {
1405 name: "valid".into(),
1406 hash,
1407 role: "ops".into(),
1408 expires_at: None,
1409 },
1410 ];
1411 let id = verify_bearer_token(&token, &keys)
1412 .expect("valid slot following a malformed-hash slot must authenticate");
1413 assert_eq!(id.name, "valid");
1414 }
1415
1416 #[test]
1427 fn rfc_timestamp_parse_rejects_malformed() {
1428 for bad in [
1429 "not-a-date",
1430 "",
1431 "2025-13-01T00:00:00Z", "2025-01-32T00:00:00Z", "2025-01-01T00:00:00", "01/01/2025", "2025-01-01T25:00:00Z", ] {
1437 assert!(
1438 RfcTimestamp::parse(bad).is_err(),
1439 "RfcTimestamp::parse must reject {bad:?}"
1440 );
1441 }
1442 }
1443
1444 #[test]
1445 fn rfc_timestamp_parse_accepts_valid() {
1446 for good in [
1447 "2025-01-01T00:00:00Z",
1448 "2025-01-01T00:00:00+00:00",
1449 "2025-12-31T23:59:59-08:00",
1450 "2099-01-01T00:00:00.123456789Z",
1451 ] {
1452 assert!(
1453 RfcTimestamp::parse(good).is_ok(),
1454 "RfcTimestamp::parse must accept {good:?}"
1455 );
1456 }
1457 }
1458
1459 #[test]
1460 fn api_key_entry_deserialize_rejects_malformed_expires_at() {
1461 let toml = r#"
1466 name = "bad-key"
1467 hash = "$argon2id$v=19$m=19456,t=2,p=1$c2FsdA$h4sh"
1468 role = "viewer"
1469 expires_at = "not-a-date"
1470 "#;
1471 let result: Result<ApiKeyEntry, _> = toml::from_str(toml);
1472 assert!(
1473 result.is_err(),
1474 "deserialization must reject malformed expires_at"
1475 );
1476 }
1477
1478 #[test]
1479 fn api_key_entry_deserialize_accepts_valid_expires_at() {
1480 let toml = r#"
1481 name = "good-key"
1482 hash = "$argon2id$v=19$m=19456,t=2,p=1$c2FsdA$h4sh"
1483 role = "viewer"
1484 expires_at = "2099-01-01T00:00:00Z"
1485 "#;
1486 let entry: ApiKeyEntry = toml::from_str(toml).expect("valid RFC 3339 must deserialize");
1487 assert!(entry.expires_at.is_some());
1488 }
1489
1490 #[test]
1491 fn api_key_entry_deserialize_accepts_missing_expires_at() {
1492 let toml = r#"
1495 name = "eternal-key"
1496 hash = "$argon2id$v=19$m=19456,t=2,p=1$c2FsdA$h4sh"
1497 role = "viewer"
1498 "#;
1499 let entry: ApiKeyEntry = toml::from_str(toml).expect("missing expires_at must deserialize");
1500 assert!(entry.expires_at.is_none());
1501 }
1502
1503 #[test]
1504 fn try_with_expiry_rejects_malformed() {
1505 let entry = ApiKeyEntry::new("k", "hash", "viewer");
1506 assert!(entry.try_with_expiry("not-a-date").is_err());
1507 }
1508
1509 #[test]
1510 fn try_with_expiry_accepts_valid() {
1511 let entry = ApiKeyEntry::new("k", "hash", "viewer")
1512 .try_with_expiry("2099-01-01T00:00:00Z")
1513 .expect("valid RFC 3339 must be accepted");
1514 assert!(entry.expires_at.is_some());
1515 }
1516
1517 #[test]
1518 fn api_key_summary_serializes_expires_at_as_rfc3339() {
1519 let summary = ApiKeySummary {
1524 name: "k".into(),
1525 role: "viewer".into(),
1526 expires_at: Some(RfcTimestamp::parse("2030-01-01T00:00:00Z").unwrap()),
1527 };
1528 let json = serde_json::to_string(&summary).unwrap();
1529 assert!(
1530 json.contains(r#""expires_at":"2030-01-01T00:00:00+00:00""#),
1531 "wire format regressed: {json}"
1532 );
1533 }
1534
1535 #[test]
1536 fn future_expiry_accepted() {
1537 let (token, hash) = generate_api_key().unwrap();
1538 let keys = vec![ApiKeyEntry {
1539 name: "test".into(),
1540 hash,
1541 role: "viewer".into(),
1542 expires_at: Some(RfcTimestamp::parse("2099-01-01T00:00:00Z").unwrap()),
1543 }];
1544 assert!(verify_bearer_token(&token, &keys).is_some());
1545 }
1546
1547 #[test]
1548 fn multiple_keys_first_match_wins() {
1549 let (token, hash) = generate_api_key().unwrap();
1550 let keys = vec![
1551 ApiKeyEntry {
1552 name: "wrong".into(),
1553 hash: "$argon2id$v=19$m=19456,t=2,p=1$invalid$invalid".into(),
1554 role: "ops".into(),
1555 expires_at: None,
1556 },
1557 ApiKeyEntry {
1558 name: "correct".into(),
1559 hash,
1560 role: "deploy".into(),
1561 expires_at: None,
1562 },
1563 ];
1564 let id = verify_bearer_token(&token, &keys).unwrap();
1565 assert_eq!(id.name, "correct");
1566 assert_eq!(id.role, "deploy");
1567 }
1568
1569 #[test]
1570 fn rate_limiter_allows_within_quota() {
1571 let config = RateLimitConfig {
1572 max_attempts_per_minute: 5,
1573 pre_auth_max_per_minute: None,
1574 ..Default::default()
1575 };
1576 let limiter = build_rate_limiter(&config);
1577 let ip: IpAddr = "10.0.0.1".parse().unwrap();
1578
1579 for _ in 0..5 {
1581 assert!(limiter.check_key(&ip).is_ok());
1582 }
1583 assert!(limiter.check_key(&ip).is_err());
1585 }
1586
1587 #[test]
1588 fn rate_limiter_separate_ips() {
1589 let config = RateLimitConfig {
1590 max_attempts_per_minute: 2,
1591 pre_auth_max_per_minute: None,
1592 ..Default::default()
1593 };
1594 let limiter = build_rate_limiter(&config);
1595 let ip1: IpAddr = "10.0.0.1".parse().unwrap();
1596 let ip2: IpAddr = "10.0.0.2".parse().unwrap();
1597
1598 assert!(limiter.check_key(&ip1).is_ok());
1600 assert!(limiter.check_key(&ip1).is_ok());
1601 assert!(limiter.check_key(&ip1).is_err());
1602
1603 assert!(limiter.check_key(&ip2).is_ok());
1605 }
1606
1607 #[test]
1608 fn extract_mtls_identity_from_cn() {
1609 let mut params = rcgen::CertificateParams::new(vec!["test-client.local".into()]).unwrap();
1611 params.distinguished_name = rcgen::DistinguishedName::new();
1612 params
1613 .distinguished_name
1614 .push(rcgen::DnType::CommonName, "test-client");
1615 let cert = params
1616 .self_signed(&rcgen::KeyPair::generate().unwrap())
1617 .unwrap();
1618 let der = cert.der();
1619
1620 let id = extract_mtls_identity(der, "ops").unwrap();
1621 assert_eq!(id.name, "test-client");
1622 assert_eq!(id.role, "ops");
1623 assert_eq!(id.method, AuthMethod::MtlsCertificate);
1624 }
1625
1626 #[test]
1627 fn extract_mtls_identity_falls_back_to_san() {
1628 let mut params =
1630 rcgen::CertificateParams::new(vec!["san-only.example.com".into()]).unwrap();
1631 params.distinguished_name = rcgen::DistinguishedName::new();
1632 let cert = params
1634 .self_signed(&rcgen::KeyPair::generate().unwrap())
1635 .unwrap();
1636 let der = cert.der();
1637
1638 let id = extract_mtls_identity(der, "viewer").unwrap();
1639 assert_eq!(id.name, "san-only.example.com");
1640 assert_eq!(id.role, "viewer");
1641 }
1642
1643 #[test]
1644 fn extract_mtls_identity_invalid_der() {
1645 assert!(extract_mtls_identity(b"not-a-cert", "viewer").is_none());
1646 }
1647
1648 use axum::{
1651 body::Body,
1652 http::{Request, StatusCode},
1653 };
1654 use tower::ServiceExt as _;
1655
1656 fn auth_router(state: Arc<AuthState>) -> axum::Router {
1657 axum::Router::new()
1658 .route("/mcp", axum::routing::post(|| async { "ok" }))
1659 .layer(axum::middleware::from_fn(move |req, next| {
1660 let s = Arc::clone(&state);
1661 auth_middleware(s, req, next)
1662 }))
1663 }
1664
1665 fn test_auth_state(keys: Vec<ApiKeyEntry>) -> Arc<AuthState> {
1666 Arc::new(AuthState {
1667 api_keys: ArcSwap::new(Arc::new(keys)),
1668 rate_limiter: None,
1669 pre_auth_limiter: None,
1670 #[cfg(feature = "oauth")]
1671 jwks_cache: None,
1672 seen_identities: Mutex::new(HashSet::new()),
1673 counters: AuthCounters::default(),
1674 })
1675 }
1676
1677 #[tokio::test]
1678 async fn middleware_rejects_no_credentials() {
1679 let state = test_auth_state(vec![]);
1680 let app = auth_router(Arc::clone(&state));
1681 let req = Request::builder()
1682 .method(axum::http::Method::POST)
1683 .uri("/mcp")
1684 .body(Body::empty())
1685 .unwrap();
1686 let resp = app.oneshot(req).await.unwrap();
1687 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
1688 let challenge = resp
1689 .headers()
1690 .get(header::WWW_AUTHENTICATE)
1691 .unwrap()
1692 .to_str()
1693 .unwrap();
1694 assert!(challenge.contains("error=\"invalid_request\""));
1695
1696 let counters = state.counters_snapshot();
1697 assert_eq!(counters.failure_missing_credential, 1);
1698 }
1699
1700 #[tokio::test]
1701 async fn middleware_accepts_valid_bearer() {
1702 let (token, hash) = generate_api_key().unwrap();
1703 let keys = vec![ApiKeyEntry {
1704 name: "test-key".into(),
1705 hash,
1706 role: "ops".into(),
1707 expires_at: None,
1708 }];
1709 let state = test_auth_state(keys);
1710 let app = auth_router(Arc::clone(&state));
1711 let req = Request::builder()
1712 .method(axum::http::Method::POST)
1713 .uri("/mcp")
1714 .header("authorization", format!("Bearer {token}"))
1715 .body(Body::empty())
1716 .unwrap();
1717 let resp = app.oneshot(req).await.unwrap();
1718 assert_eq!(resp.status(), StatusCode::OK);
1719
1720 let counters = state.counters_snapshot();
1721 assert_eq!(counters.success_bearer, 1);
1722 }
1723
1724 #[tokio::test]
1725 async fn middleware_rejects_wrong_bearer() {
1726 let (_token, hash) = generate_api_key().unwrap();
1727 let keys = vec![ApiKeyEntry {
1728 name: "test-key".into(),
1729 hash,
1730 role: "ops".into(),
1731 expires_at: None,
1732 }];
1733 let state = test_auth_state(keys);
1734 let app = auth_router(Arc::clone(&state));
1735 let req = Request::builder()
1736 .method(axum::http::Method::POST)
1737 .uri("/mcp")
1738 .header("authorization", "Bearer wrong-token-here")
1739 .body(Body::empty())
1740 .unwrap();
1741 let resp = app.oneshot(req).await.unwrap();
1742 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
1743 let challenge = resp
1744 .headers()
1745 .get(header::WWW_AUTHENTICATE)
1746 .unwrap()
1747 .to_str()
1748 .unwrap();
1749 assert!(challenge.contains("error=\"invalid_token\""));
1750
1751 let counters = state.counters_snapshot();
1752 assert_eq!(counters.failure_invalid_credential, 1);
1753 }
1754
1755 #[tokio::test]
1756 async fn middleware_rate_limits() {
1757 let state = Arc::new(AuthState {
1758 api_keys: ArcSwap::new(Arc::new(vec![])),
1759 rate_limiter: Some(build_rate_limiter(&RateLimitConfig {
1760 max_attempts_per_minute: 1,
1761 pre_auth_max_per_minute: None,
1762 ..Default::default()
1763 })),
1764 pre_auth_limiter: None,
1765 #[cfg(feature = "oauth")]
1766 jwks_cache: None,
1767 seen_identities: Mutex::new(HashSet::new()),
1768 counters: AuthCounters::default(),
1769 });
1770 let app = auth_router(state);
1771
1772 let req = Request::builder()
1774 .method(axum::http::Method::POST)
1775 .uri("/mcp")
1776 .body(Body::empty())
1777 .unwrap();
1778 let resp = app.clone().oneshot(req).await.unwrap();
1779 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
1780
1781 }
1786
1787 #[test]
1793 fn rate_limit_semantics_failed_only() {
1794 let config = RateLimitConfig {
1795 max_attempts_per_minute: 3,
1796 pre_auth_max_per_minute: None,
1797 ..Default::default()
1798 };
1799 let limiter = build_rate_limiter(&config);
1800 let ip: IpAddr = "192.168.1.100".parse().unwrap();
1801
1802 assert!(
1804 limiter.check_key(&ip).is_ok(),
1805 "failure 1 should be allowed"
1806 );
1807 assert!(
1808 limiter.check_key(&ip).is_ok(),
1809 "failure 2 should be allowed"
1810 );
1811 assert!(
1812 limiter.check_key(&ip).is_ok(),
1813 "failure 3 should be allowed"
1814 );
1815 assert!(
1816 limiter.check_key(&ip).is_err(),
1817 "failure 4 should be blocked"
1818 );
1819
1820 }
1829
1830 #[test]
1835 fn pre_auth_default_multiplier_is_10x() {
1836 let config = RateLimitConfig {
1837 max_attempts_per_minute: 5,
1838 pre_auth_max_per_minute: None,
1839 ..Default::default()
1840 };
1841 let limiter = build_pre_auth_limiter(&config);
1842 let ip: IpAddr = "10.0.0.1".parse().unwrap();
1843
1844 for i in 0..50 {
1846 assert!(
1847 limiter.check_key(&ip).is_ok(),
1848 "pre-auth attempt {i} (of expected 50) should be allowed under default 10x multiplier"
1849 );
1850 }
1851 assert!(
1853 limiter.check_key(&ip).is_err(),
1854 "pre-auth attempt 51 should be blocked (quota is 50, not unbounded)"
1855 );
1856 }
1857
1858 #[test]
1861 fn pre_auth_explicit_override_wins() {
1862 let config = RateLimitConfig {
1863 max_attempts_per_minute: 100, pre_auth_max_per_minute: Some(2), ..Default::default()
1866 };
1867 let limiter = build_pre_auth_limiter(&config);
1868 let ip: IpAddr = "10.0.0.2".parse().unwrap();
1869
1870 assert!(limiter.check_key(&ip).is_ok(), "attempt 1 allowed");
1871 assert!(limiter.check_key(&ip).is_ok(), "attempt 2 allowed");
1872 assert!(
1873 limiter.check_key(&ip).is_err(),
1874 "attempt 3 must be blocked (explicit override of 2 wins over 10x default of 1000)"
1875 );
1876 }
1877
1878 #[tokio::test]
1884 async fn pre_auth_gate_blocks_before_argon2_verification() {
1885 let (_token, hash) = generate_api_key().unwrap();
1886 let keys = vec![ApiKeyEntry {
1887 name: "test-key".into(),
1888 hash,
1889 role: "ops".into(),
1890 expires_at: None,
1891 }];
1892 let config = RateLimitConfig {
1893 max_attempts_per_minute: 100,
1894 pre_auth_max_per_minute: Some(1),
1895 ..Default::default()
1896 };
1897 let state = Arc::new(AuthState {
1898 api_keys: ArcSwap::new(Arc::new(keys)),
1899 rate_limiter: None,
1900 pre_auth_limiter: Some(build_pre_auth_limiter(&config)),
1901 #[cfg(feature = "oauth")]
1902 jwks_cache: None,
1903 seen_identities: Mutex::new(HashSet::new()),
1904 counters: AuthCounters::default(),
1905 });
1906 let app = auth_router(Arc::clone(&state));
1907 let peer: SocketAddr = "10.0.0.10:54321".parse().unwrap();
1908
1909 let mut req1 = Request::builder()
1912 .method(axum::http::Method::POST)
1913 .uri("/mcp")
1914 .header("authorization", "Bearer obviously-not-a-real-token")
1915 .body(Body::empty())
1916 .unwrap();
1917 req1.extensions_mut().insert(ConnectInfo(peer));
1918 let resp1 = app.clone().oneshot(req1).await.unwrap();
1919 assert_eq!(
1920 resp1.status(),
1921 StatusCode::UNAUTHORIZED,
1922 "first attempt: gate has quota, falls through to bearer auth which fails with 401"
1923 );
1924
1925 let mut req2 = Request::builder()
1928 .method(axum::http::Method::POST)
1929 .uri("/mcp")
1930 .header("authorization", "Bearer also-not-a-real-token")
1931 .body(Body::empty())
1932 .unwrap();
1933 req2.extensions_mut().insert(ConnectInfo(peer));
1934 let resp2 = app.oneshot(req2).await.unwrap();
1935 assert_eq!(
1936 resp2.status(),
1937 StatusCode::TOO_MANY_REQUESTS,
1938 "second attempt from same IP: pre-auth gate must reject with 429"
1939 );
1940
1941 let counters = state.counters_snapshot();
1942 assert_eq!(
1943 counters.failure_pre_auth_gate, 1,
1944 "exactly one request must have been rejected by the pre-auth gate"
1945 );
1946 assert_eq!(
1950 counters.failure_invalid_credential, 1,
1951 "bearer verification must run exactly once (only the un-gated first request)"
1952 );
1953 }
1954
1955 #[tokio::test]
1962 async fn pre_auth_gate_does_not_throttle_mtls() {
1963 let config = RateLimitConfig {
1964 max_attempts_per_minute: 100,
1965 pre_auth_max_per_minute: Some(1), ..Default::default()
1967 };
1968 let state = Arc::new(AuthState {
1969 api_keys: ArcSwap::new(Arc::new(vec![])),
1970 rate_limiter: None,
1971 pre_auth_limiter: Some(build_pre_auth_limiter(&config)),
1972 #[cfg(feature = "oauth")]
1973 jwks_cache: None,
1974 seen_identities: Mutex::new(HashSet::new()),
1975 counters: AuthCounters::default(),
1976 });
1977 let app = auth_router(Arc::clone(&state));
1978 let peer: SocketAddr = "10.0.0.20:54321".parse().unwrap();
1979 let identity = AuthIdentity {
1980 name: "cn=test-client".into(),
1981 role: "viewer".into(),
1982 method: AuthMethod::MtlsCertificate,
1983 raw_token: None,
1984 sub: None,
1985 };
1986 let tls_info = TlsConnInfo::new(peer, Some(identity));
1987
1988 for i in 0..3 {
1989 let mut req = Request::builder()
1990 .method(axum::http::Method::POST)
1991 .uri("/mcp")
1992 .body(Body::empty())
1993 .unwrap();
1994 req.extensions_mut().insert(ConnectInfo(tls_info.clone()));
1995 let resp = app.clone().oneshot(req).await.unwrap();
1996 assert_eq!(
1997 resp.status(),
1998 StatusCode::OK,
1999 "mTLS request {i} must succeed: pre-auth gate must not apply to mTLS callers"
2000 );
2001 }
2002
2003 let counters = state.counters_snapshot();
2004 assert_eq!(
2005 counters.failure_pre_auth_gate, 0,
2006 "pre-auth gate counter must remain at zero: mTLS bypasses the gate"
2007 );
2008 assert_eq!(
2009 counters.success_mtls, 3,
2010 "all three mTLS requests must have been counted as successful"
2011 );
2012 }
2013
2014 #[test]
2019 fn extract_bearer_accepts_canonical_case() {
2020 assert_eq!(extract_bearer("Bearer abc123"), Some("abc123"));
2021 }
2022
2023 #[test]
2024 fn extract_bearer_is_case_insensitive_per_rfc7235() {
2025 for header in &[
2029 "bearer abc123",
2030 "BEARER abc123",
2031 "BeArEr abc123",
2032 "bEaReR abc123",
2033 ] {
2034 assert_eq!(
2035 extract_bearer(header),
2036 Some("abc123"),
2037 "header {header:?} must parse as a Bearer token (RFC 7235 §2.1)"
2038 );
2039 }
2040 }
2041
2042 #[test]
2043 fn extract_bearer_rejects_other_schemes() {
2044 assert_eq!(extract_bearer("Basic dXNlcjpwYXNz"), None);
2045 assert_eq!(extract_bearer("Digest username=\"x\""), None);
2046 assert_eq!(extract_bearer("Token abc123"), None);
2047 }
2048
2049 #[test]
2050 fn extract_bearer_rejects_malformed() {
2051 assert_eq!(extract_bearer(""), None);
2053 assert_eq!(extract_bearer("Bearer"), None);
2054 assert_eq!(extract_bearer("Bearer "), None);
2055 assert_eq!(extract_bearer("Bearer "), None);
2056 }
2057
2058 #[test]
2059 fn extract_bearer_tolerates_extra_separator_whitespace() {
2060 assert_eq!(extract_bearer("Bearer abc123"), Some("abc123"));
2062 assert_eq!(extract_bearer("Bearer abc123"), Some("abc123"));
2063 }
2064
2065 #[test]
2071 fn auth_identity_debug_redacts_raw_token() {
2072 let id = AuthIdentity {
2073 name: "alice".into(),
2074 role: "admin".into(),
2075 method: AuthMethod::OAuthJwt,
2076 raw_token: Some(SecretString::from("super-secret-jwt-payload-xyz")),
2077 sub: Some("keycloak-uuid-2f3c8b".into()),
2078 };
2079 let dbg = format!("{id:?}");
2080
2081 assert!(dbg.contains("alice"), "name should be visible: {dbg}");
2083 assert!(dbg.contains("admin"), "role should be visible: {dbg}");
2084 assert!(dbg.contains("OAuthJwt"), "method should be visible: {dbg}");
2085
2086 assert!(
2088 !dbg.contains("super-secret-jwt-payload-xyz"),
2089 "raw_token must be redacted in Debug output: {dbg}"
2090 );
2091 assert!(
2092 !dbg.contains("keycloak-uuid-2f3c8b"),
2093 "sub must be redacted in Debug output: {dbg}"
2094 );
2095 assert!(
2096 dbg.contains("<redacted>"),
2097 "redaction marker missing: {dbg}"
2098 );
2099 }
2100
2101 #[test]
2102 fn auth_identity_debug_marks_absent_secrets() {
2103 let id = AuthIdentity {
2106 name: "viewer-key".into(),
2107 role: "viewer".into(),
2108 method: AuthMethod::BearerToken,
2109 raw_token: None,
2110 sub: None,
2111 };
2112 let dbg = format!("{id:?}");
2113 assert!(
2114 dbg.contains("<none>"),
2115 "absent secrets should be marked: {dbg}"
2116 );
2117 assert!(
2118 !dbg.contains("<redacted>"),
2119 "no <redacted> marker when secrets are absent: {dbg}"
2120 );
2121 }
2122
2123 #[test]
2124 fn api_key_entry_debug_redacts_hash() {
2125 let entry = ApiKeyEntry {
2126 name: "viewer-key".into(),
2127 hash: "$argon2id$v=19$m=19456,t=2,p=1$c2FsdHNhbHQ$h4sh3dPa55w0rd".into(),
2129 role: "viewer".into(),
2130 expires_at: Some(RfcTimestamp::parse("2030-01-01T00:00:00Z").unwrap()),
2131 };
2132 let dbg = format!("{entry:?}");
2133
2134 assert!(dbg.contains("viewer-key"));
2136 assert!(dbg.contains("viewer"));
2137 assert!(dbg.contains("2030-01-01T00:00:00+00:00"));
2138
2139 assert!(
2141 !dbg.contains("$argon2id$"),
2142 "argon2 hash leaked into Debug output: {dbg}"
2143 );
2144 assert!(
2145 !dbg.contains("h4sh3dPa55w0rd"),
2146 "hash digest leaked into Debug output: {dbg}"
2147 );
2148 assert!(
2149 dbg.contains("<redacted>"),
2150 "redaction marker missing: {dbg}"
2151 );
2152 }
2153
2154 #[test]
2165 fn auth_failure_class_as_str_exact_strings() {
2166 assert_eq!(
2167 AuthFailureClass::MissingCredential.as_str(),
2168 "missing_credential"
2169 );
2170 assert_eq!(
2171 AuthFailureClass::InvalidCredential.as_str(),
2172 "invalid_credential"
2173 );
2174 assert_eq!(
2175 AuthFailureClass::ExpiredCredential.as_str(),
2176 "expired_credential"
2177 );
2178 assert_eq!(AuthFailureClass::RateLimited.as_str(), "rate_limited");
2179 assert_eq!(AuthFailureClass::PreAuthGate.as_str(), "pre_auth_gate");
2180 }
2181
2182 #[test]
2183 fn auth_failure_class_response_body_exact_strings() {
2184 assert_eq!(
2185 AuthFailureClass::MissingCredential.response_body(),
2186 "unauthorized: missing credential"
2187 );
2188 assert_eq!(
2189 AuthFailureClass::InvalidCredential.response_body(),
2190 "unauthorized: invalid credential"
2191 );
2192 assert_eq!(
2193 AuthFailureClass::ExpiredCredential.response_body(),
2194 "unauthorized: expired credential"
2195 );
2196 assert_eq!(
2197 AuthFailureClass::RateLimited.response_body(),
2198 "rate limited"
2199 );
2200 assert_eq!(
2201 AuthFailureClass::PreAuthGate.response_body(),
2202 "rate limited (pre-auth)"
2203 );
2204 }
2205
2206 #[test]
2207 fn auth_failure_class_bearer_error_exact_strings() {
2208 assert_eq!(
2209 AuthFailureClass::MissingCredential.bearer_error(),
2210 (
2211 "invalid_request",
2212 "missing bearer token or mTLS client certificate"
2213 )
2214 );
2215 assert_eq!(
2216 AuthFailureClass::InvalidCredential.bearer_error(),
2217 ("invalid_token", "token is invalid")
2218 );
2219 assert_eq!(
2220 AuthFailureClass::ExpiredCredential.bearer_error(),
2221 ("invalid_token", "token is expired")
2222 );
2223 assert_eq!(
2224 AuthFailureClass::RateLimited.bearer_error(),
2225 ("invalid_request", "too many failed authentication attempts")
2226 );
2227 assert_eq!(
2228 AuthFailureClass::PreAuthGate.bearer_error(),
2229 (
2230 "invalid_request",
2231 "too many unauthenticated requests from this source"
2232 )
2233 );
2234 }
2235
2236 #[test]
2245 fn auth_config_summary_bearer_true_when_keys_present() {
2246 let (_token, hash) = generate_api_key().unwrap();
2247 let cfg = AuthConfig::with_keys(vec![ApiKeyEntry::new("k", hash, "viewer")]);
2248 let s = cfg.summary();
2249 assert!(s.enabled, "summary.enabled must reflect AuthConfig.enabled");
2250 assert!(
2251 s.bearer,
2252 "summary.bearer must be true when api_keys is non-empty (kills `!` deletion at L615)"
2253 );
2254 assert!(!s.mtls, "summary.mtls must be false when mtls is None");
2255 assert!(!s.oauth, "summary.oauth must be false when oauth is None");
2256 assert_eq!(s.api_keys.len(), 1);
2257 assert_eq!(s.api_keys[0].name, "k");
2258 assert_eq!(s.api_keys[0].role, "viewer");
2259 }
2260
2261 #[test]
2262 fn auth_config_summary_bearer_false_when_no_keys() {
2263 let cfg = AuthConfig::with_keys(vec![]);
2264 let s = cfg.summary();
2265 assert!(
2266 !s.bearer,
2267 "summary.bearer must be false when api_keys is empty (kills `!` deletion at L615)"
2268 );
2269 assert!(s.api_keys.is_empty());
2270 }
2271}