1use std::{
10 collections::HashSet,
11 net::{IpAddr, SocketAddr},
12 num::NonZeroU32,
13 path::PathBuf,
14 sync::{
15 Arc, LazyLock, Mutex,
16 atomic::{AtomicU64, Ordering},
17 },
18 time::Duration,
19};
20
21use arc_swap::ArcSwap;
22use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier, password_hash::SaltString};
23use axum::{
24 body::Body,
25 extract::ConnectInfo,
26 http::{Request, header},
27 middleware::Next,
28 response::{IntoResponse, Response},
29};
30use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
31use secrecy::SecretString;
32use serde::Deserialize;
33use x509_parser::prelude::*;
34
35use crate::{bounded_limiter::BoundedKeyedLimiter, error::McpxError};
36
37#[derive(Clone)]
46#[non_exhaustive]
47pub struct AuthIdentity {
48 pub name: String,
50 pub role: String,
52 pub method: AuthMethod,
54 pub raw_token: Option<SecretString>,
60 pub sub: Option<String>,
63}
64
65impl std::fmt::Debug for AuthIdentity {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 f.debug_struct("AuthIdentity")
70 .field("name", &self.name)
71 .field("role", &self.role)
72 .field("method", &self.method)
73 .field(
74 "raw_token",
75 &if self.raw_token.is_some() {
76 "<redacted>"
77 } else {
78 "<none>"
79 },
80 )
81 .field(
82 "sub",
83 &if self.sub.is_some() {
84 "<redacted>"
85 } else {
86 "<none>"
87 },
88 )
89 .finish()
90 }
91}
92
/// Which mechanism authenticated an [`AuthIdentity`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum AuthMethod {
    /// API key sent as `Authorization: Bearer <token>` and checked against a
    /// stored Argon2 PHC hash (see `verify_bearer_token`).
    BearerToken,
    /// Identity taken from a client certificate (see `extract_mtls_identity`).
    MtlsCertificate,
    /// JWT accepted by the OAuth validator; presumably set by `crate::oauth`
    /// (not visible in this module).
    OAuthJwt,
}
104
/// Why an authentication attempt was rejected.
///
/// Each class carries a log label (`as_str`), an `(error,
/// error_description)` pair for the `WWW-Authenticate` challenge
/// (`bearer_error`), and a short plain-text response body
/// (`response_body`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AuthFailureClass {
    /// No `Authorization` header was sent.
    MissingCredential,
    /// A credential was sent but was unusable: a non-Bearer scheme, a
    /// malformed header, or a token that matched nothing.
    InvalidCredential,
    /// A JWT was reported as expired by the OAuth validator.
    #[cfg_attr(not(feature = "oauth"), allow(dead_code))]
    ExpiredCredential,
    /// The per-IP failure limiter rejected the request after a failed attempt.
    RateLimited,
    /// The per-IP pre-authentication gate rejected the request before any
    /// credential was inspected.
    PreAuthGate,
}

/// Every fixed string associated with one [`AuthFailureClass`].
struct AuthFailureText {
    label: &'static str,
    bearer_error: &'static str,
    bearer_description: &'static str,
    body: &'static str,
}

impl AuthFailureClass {
    /// Single lookup table for all strings a failure class exposes.
    const fn text(self) -> AuthFailureText {
        match self {
            Self::MissingCredential => AuthFailureText {
                label: "missing_credential",
                bearer_error: "invalid_request",
                bearer_description: "missing bearer token or mTLS client certificate",
                body: "unauthorized: missing credential",
            },
            Self::InvalidCredential => AuthFailureText {
                label: "invalid_credential",
                bearer_error: "invalid_token",
                bearer_description: "token is invalid",
                body: "unauthorized: invalid credential",
            },
            Self::ExpiredCredential => AuthFailureText {
                label: "expired_credential",
                bearer_error: "invalid_token",
                bearer_description: "token is expired",
                body: "unauthorized: expired credential",
            },
            Self::RateLimited => AuthFailureText {
                label: "rate_limited",
                bearer_error: "invalid_request",
                bearer_description: "too many failed authentication attempts",
                body: "rate limited",
            },
            Self::PreAuthGate => AuthFailureText {
                label: "pre_auth_gate",
                bearer_error: "invalid_request",
                bearer_description: "too many unauthenticated requests from this source",
                body: "rate limited (pre-auth)",
            },
        }
    }

    /// Stable snake_case label used in log fields.
    fn as_str(self) -> &'static str {
        self.text().label
    }

    /// `(error, error_description)` for the `WWW-Authenticate` header.
    fn bearer_error(self) -> (&'static str, &'static str) {
        let text = self.text();
        (text.bearer_error, text.bearer_description)
    }

    /// Plain-text body sent with the rejection.
    fn response_body(self) -> &'static str {
        self.text().body
    }
}
155
/// Serializable point-in-time copy of the authentication counters.
///
/// Each field is loaded independently with relaxed ordering, so the values
/// are not guaranteed to be mutually consistent under concurrent updates.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
#[non_exhaustive]
pub struct AuthCountersSnapshot {
    /// Successful mTLS authentications.
    pub success_mtls: u64,
    /// Successful API-key bearer authentications.
    pub success_bearer: u64,
    /// Successful OAuth JWT authentications.
    pub success_oauth_jwt: u64,
    /// Requests rejected for having no credential.
    pub failure_missing_credential: u64,
    /// Requests rejected for an invalid credential.
    pub failure_invalid_credential: u64,
    /// Requests rejected for an expired credential.
    pub failure_expired_credential: u64,
    /// Requests rejected by the post-failure per-IP limiter.
    pub failure_rate_limited: u64,
    /// Requests rejected by the pre-auth per-IP gate.
    pub failure_pre_auth_gate: u64,
}
178
179#[derive(Debug, Default)]
181pub(crate) struct AuthCounters {
182 success_mtls: AtomicU64,
183 success_bearer: AtomicU64,
184 success_oauth_jwt: AtomicU64,
185 failure_missing_credential: AtomicU64,
186 failure_invalid_credential: AtomicU64,
187 failure_expired_credential: AtomicU64,
188 failure_rate_limited: AtomicU64,
189 failure_pre_auth_gate: AtomicU64,
190}
191
192impl AuthCounters {
193 fn record_success(&self, method: AuthMethod) {
194 match method {
195 AuthMethod::MtlsCertificate => {
196 self.success_mtls.fetch_add(1, Ordering::Relaxed);
197 }
198 AuthMethod::BearerToken => {
199 self.success_bearer.fetch_add(1, Ordering::Relaxed);
200 }
201 AuthMethod::OAuthJwt => {
202 self.success_oauth_jwt.fetch_add(1, Ordering::Relaxed);
203 }
204 }
205 }
206
207 fn record_failure(&self, class: AuthFailureClass) {
208 match class {
209 AuthFailureClass::MissingCredential => {
210 self.failure_missing_credential
211 .fetch_add(1, Ordering::Relaxed);
212 }
213 AuthFailureClass::InvalidCredential => {
214 self.failure_invalid_credential
215 .fetch_add(1, Ordering::Relaxed);
216 }
217 AuthFailureClass::ExpiredCredential => {
218 self.failure_expired_credential
219 .fetch_add(1, Ordering::Relaxed);
220 }
221 AuthFailureClass::RateLimited => {
222 self.failure_rate_limited.fetch_add(1, Ordering::Relaxed);
223 }
224 AuthFailureClass::PreAuthGate => {
225 self.failure_pre_auth_gate.fetch_add(1, Ordering::Relaxed);
226 }
227 }
228 }
229
230 fn snapshot(&self) -> AuthCountersSnapshot {
231 AuthCountersSnapshot {
232 success_mtls: self.success_mtls.load(Ordering::Relaxed),
233 success_bearer: self.success_bearer.load(Ordering::Relaxed),
234 success_oauth_jwt: self.success_oauth_jwt.load(Ordering::Relaxed),
235 failure_missing_credential: self.failure_missing_credential.load(Ordering::Relaxed),
236 failure_invalid_credential: self.failure_invalid_credential.load(Ordering::Relaxed),
237 failure_expired_credential: self.failure_expired_credential.load(Ordering::Relaxed),
238 failure_rate_limited: self.failure_rate_limited.load(Ordering::Relaxed),
239 failure_pre_auth_gate: self.failure_pre_auth_gate.load(Ordering::Relaxed),
240 }
241 }
242}
243
244#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
256#[non_exhaustive]
257pub struct RfcTimestamp(chrono::DateTime<chrono::FixedOffset>);
258
259impl RfcTimestamp {
260 pub fn parse(s: &str) -> Result<Self, chrono::ParseError> {
268 chrono::DateTime::parse_from_rfc3339(s).map(Self)
269 }
270
271 #[must_use]
273 pub fn as_datetime(&self) -> &chrono::DateTime<chrono::FixedOffset> {
274 &self.0
275 }
276
277 #[must_use]
279 pub fn into_inner(self) -> chrono::DateTime<chrono::FixedOffset> {
280 self.0
281 }
282}
283
284impl std::fmt::Display for RfcTimestamp {
285 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286 write!(f, "{}", self.0.to_rfc3339())
288 }
289}
290
291impl std::fmt::Debug for RfcTimestamp {
292 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
293 write!(f, "{}", self.0.to_rfc3339())
298 }
299}
300
301impl<'de> Deserialize<'de> for RfcTimestamp {
302 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
303 where
304 D: serde::Deserializer<'de>,
305 {
306 let s = String::deserialize(deserializer)?;
310 Self::parse(&s).map_err(serde::de::Error::custom)
311 }
312}
313
314impl serde::Serialize for RfcTimestamp {
315 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
316 where
317 S: serde::Serializer,
318 {
319 serializer.serialize_str(&self.0.to_rfc3339())
320 }
321}
322
323impl From<chrono::DateTime<chrono::FixedOffset>> for RfcTimestamp {
324 fn from(value: chrono::DateTime<chrono::FixedOffset>) -> Self {
325 Self(value)
326 }
327}
328
329#[derive(Clone, Deserialize)]
336#[non_exhaustive]
337pub struct ApiKeyEntry {
338 pub name: String,
340 pub hash: String,
342 pub role: String,
344 pub expires_at: Option<RfcTimestamp>,
349}
350
351impl std::fmt::Debug for ApiKeyEntry {
352 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355 f.debug_struct("ApiKeyEntry")
356 .field("name", &self.name)
357 .field("hash", &"<redacted>")
358 .field("role", &self.role)
359 .field("expires_at", &self.expires_at)
360 .finish()
361 }
362}
363
364impl ApiKeyEntry {
365 #[must_use]
367 pub fn new(name: impl Into<String>, hash: impl Into<String>, role: impl Into<String>) -> Self {
368 Self {
369 name: name.into(),
370 hash: hash.into(),
371 role: role.into(),
372 expires_at: None,
373 }
374 }
375
376 #[must_use]
381 pub fn with_expiry(mut self, expires_at: RfcTimestamp) -> Self {
382 self.expires_at = Some(expires_at);
383 self
384 }
385
386 pub fn try_with_expiry(
394 mut self,
395 expires_at: impl AsRef<str>,
396 ) -> Result<Self, chrono::ParseError> {
397 self.expires_at = Some(RfcTimestamp::parse(expires_at.as_ref())?);
398 Ok(self)
399 }
400}
401
/// mTLS client-certificate settings, including CRL (certificate revocation
/// list) handling.
///
/// Only the deserialization defaults are defined in this module. The fields
/// are consumed by TLS/CRL code elsewhere in the crate, so the per-field
/// descriptions below are inferred from names and defaults.
/// NOTE(review): confirm them against the consuming code.
#[derive(Debug, Clone, Deserialize)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "mTLS CRL behavior is intentionally configured as independent booleans"
)]
#[non_exhaustive]
pub struct MtlsConfig {
    /// Path to the CA certificate used to validate client certificates.
    pub ca_cert_path: PathBuf,
    /// Whether a client certificate is mandatory. Defaults to `false`.
    #[serde(default)]
    pub required: bool,
    /// Role given to certificate-derived identities, presumably passed to
    /// `extract_mtls_identity`. Defaults to `"viewer"`.
    #[serde(default = "default_mtls_role")]
    pub default_role: String,
    /// Enables CRL checking. Defaults to `true`.
    #[serde(default = "default_true")]
    pub crl_enabled: bool,
    /// Optional CRL refresh interval (humantime syntax, e.g. `"1h"`).
    /// Defaults to `None`.
    #[serde(default, with = "humantime_serde::option")]
    pub crl_refresh_interval: Option<Duration>,
    /// Timeout for a single CRL fetch. Defaults to 30 s.
    #[serde(default = "default_crl_fetch_timeout", with = "humantime_serde")]
    pub crl_fetch_timeout: Duration,
    /// How long a stale CRL may still be used. Defaults to 24 h.
    #[serde(default = "default_crl_stale_grace", with = "humantime_serde")]
    pub crl_stale_grace: Duration,
    /// Deny connections when revocation status cannot be determined.
    /// Defaults to `false`.
    #[serde(default)]
    pub crl_deny_on_unavailable: bool,
    /// Check revocation only for the end-entity certificate. Defaults to `false`.
    #[serde(default)]
    pub crl_end_entity_only: bool,
    /// Permit fetching CRLs over plain HTTP. Defaults to `true`.
    #[serde(default = "default_true")]
    pub crl_allow_http: bool,
    /// Enforce CRL expiration. Defaults to `true`.
    #[serde(default = "default_true")]
    pub crl_enforce_expiration: bool,
    /// Upper bound on concurrent CRL fetches. Defaults to 4.
    #[serde(default = "default_crl_max_concurrent_fetches")]
    pub crl_max_concurrent_fetches: usize,
    /// Maximum CRL response size in bytes. Defaults to 5 MiB.
    #[serde(default = "default_crl_max_response_bytes")]
    pub crl_max_response_bytes: u64,
    /// Rate limit for discovering new CRL URLs, per minute. Defaults to 60.
    #[serde(default = "default_crl_discovery_rate_per_min")]
    pub crl_discovery_rate_per_min: u32,
    /// Cap on per-host semaphores. Defaults to 1024.
    #[serde(default = "default_crl_max_host_semaphores")]
    pub crl_max_host_semaphores: usize,
    /// Cap on remembered CRL URLs. Defaults to 4096.
    #[serde(default = "default_crl_max_seen_urls")]
    pub crl_max_seen_urls: usize,
    /// Cap on cached CRL entries. Defaults to 1024.
    #[serde(default = "default_crl_max_cache_entries")]
    pub crl_max_cache_entries: usize,
}

/// Serde default for `MtlsConfig::default_role`: `"viewer"`.
fn default_mtls_role() -> String {
    "viewer".into()
}

/// Serde default for boolean fields that default to enabled.
const fn default_true() -> bool {
    true
}

/// Serde default for `MtlsConfig::crl_fetch_timeout`: 30 seconds.
const fn default_crl_fetch_timeout() -> Duration {
    Duration::from_secs(30)
}

/// Serde default for `MtlsConfig::crl_stale_grace`: 24 hours.
const fn default_crl_stale_grace() -> Duration {
    Duration::from_hours(24)
}

/// Serde default for `MtlsConfig::crl_max_concurrent_fetches`: 4.
const fn default_crl_max_concurrent_fetches() -> usize {
    4
}

/// Serde default for `MtlsConfig::crl_max_response_bytes`: 5 MiB.
const fn default_crl_max_response_bytes() -> u64 {
    5 * 1024 * 1024
}

/// Serde default for `MtlsConfig::crl_discovery_rate_per_min`: 60.
const fn default_crl_discovery_rate_per_min() -> u32 {
    60
}

/// Serde default for `MtlsConfig::crl_max_host_semaphores`: 1024.
const fn default_crl_max_host_semaphores() -> usize {
    1024
}

/// Serde default for `MtlsConfig::crl_max_seen_urls`: 4096.
const fn default_crl_max_seen_urls() -> usize {
    4096
}

/// Serde default for `MtlsConfig::crl_max_cache_entries`: 1024.
const fn default_crl_max_cache_entries() -> usize {
    1024
}
543
544#[derive(Debug, Clone, Deserialize)]
559#[non_exhaustive]
560pub struct RateLimitConfig {
561 #[serde(default = "default_max_attempts")]
564 pub max_attempts_per_minute: u32,
565 #[serde(default)]
573 pub pre_auth_max_per_minute: Option<u32>,
574 #[serde(default = "default_max_tracked_keys")]
579 pub max_tracked_keys: usize,
580 #[serde(default = "default_idle_eviction", with = "humantime_serde")]
583 pub idle_eviction: Duration,
584}
585
586impl Default for RateLimitConfig {
587 fn default() -> Self {
588 Self {
589 max_attempts_per_minute: default_max_attempts(),
590 pre_auth_max_per_minute: None,
591 max_tracked_keys: default_max_tracked_keys(),
592 idle_eviction: default_idle_eviction(),
593 }
594 }
595}
596
597impl RateLimitConfig {
598 #[must_use]
602 pub fn new(max_attempts_per_minute: u32) -> Self {
603 Self {
604 max_attempts_per_minute,
605 ..Self::default()
606 }
607 }
608
609 #[must_use]
612 pub fn with_pre_auth_max_per_minute(mut self, quota: u32) -> Self {
613 self.pre_auth_max_per_minute = Some(quota);
614 self
615 }
616
617 #[must_use]
619 pub fn with_max_tracked_keys(mut self, max: usize) -> Self {
620 self.max_tracked_keys = max;
621 self
622 }
623
624 #[must_use]
626 pub fn with_idle_eviction(mut self, idle: Duration) -> Self {
627 self.idle_eviction = idle;
628 self
629 }
630}
631
632fn default_max_attempts() -> u32 {
633 30
634}
635
636fn default_max_tracked_keys() -> usize {
637 10_000
638}
639
640fn default_idle_eviction() -> Duration {
641 Duration::from_mins(15)
642}
643
644#[derive(Debug, Clone, Default, Deserialize)]
646#[non_exhaustive]
647pub struct AuthConfig {
648 #[serde(default)]
650 pub enabled: bool,
651 #[serde(default)]
653 pub api_keys: Vec<ApiKeyEntry>,
654 pub mtls: Option<MtlsConfig>,
656 pub rate_limit: Option<RateLimitConfig>,
658 #[cfg(feature = "oauth")]
660 pub oauth: Option<crate::oauth::OAuthConfig>,
661}
662
663impl AuthConfig {
664 #[must_use]
666 pub fn with_keys(keys: Vec<ApiKeyEntry>) -> Self {
667 Self {
668 enabled: true,
669 api_keys: keys,
670 mtls: None,
671 rate_limit: None,
672 #[cfg(feature = "oauth")]
673 oauth: None,
674 }
675 }
676
677 #[must_use]
679 pub fn with_rate_limit(mut self, rate_limit: RateLimitConfig) -> Self {
680 self.rate_limit = Some(rate_limit);
681 self
682 }
683}
684
/// Secret-free view of an [`ApiKeyEntry`]: everything except the hash.
#[derive(Debug, Clone, serde::Serialize)]
#[non_exhaustive]
pub struct ApiKeySummary {
    /// Key name.
    pub name: String,
    /// Role granted by the key.
    pub role: String,
    /// Expiry, serialized as an RFC 3339 string when present.
    pub expires_at: Option<RfcTimestamp>,
}

/// Serializable overview of which authentication methods are configured.
#[derive(Debug, Clone, serde::Serialize)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "this is a flat summary of independent auth-method booleans"
)]
#[non_exhaustive]
pub struct AuthConfigSummary {
    /// Copy of `AuthConfig::enabled`.
    pub enabled: bool,
    /// `true` when at least one API key is configured.
    pub bearer: bool,
    /// `true` when an mTLS section is present.
    pub mtls: bool,
    /// `true` when an OAuth section is present (always `false` without the
    /// `oauth` feature).
    pub oauth: bool,
    /// Per-key summaries, without hashes.
    pub api_keys: Vec<ApiKeySummary>,
}
719
720impl AuthConfig {
721 #[must_use]
723 pub fn summary(&self) -> AuthConfigSummary {
724 AuthConfigSummary {
725 enabled: self.enabled,
726 bearer: !self.api_keys.is_empty(),
727 mtls: self.mtls.is_some(),
728 #[cfg(feature = "oauth")]
729 oauth: self.oauth.is_some(),
730 #[cfg(not(feature = "oauth"))]
731 oauth: false,
732 api_keys: self
733 .api_keys
734 .iter()
735 .map(|k| ApiKeySummary {
736 name: k.name.clone(),
737 role: k.role.clone(),
738 expires_at: k.expires_at,
739 })
740 .collect(),
741 }
742 }
743}
744
/// Rate limiter keyed by client IP, bounded in the number of tracked keys
/// (see `build_rate_limiter` / `build_pre_auth_limiter`).
pub(crate) type KeyedLimiter = BoundedKeyedLimiter<IpAddr>;
748
/// Per-connection data for TLS connections, read by `auth_middleware` from
/// the `ConnectInfo<TlsConnInfo>` request extension.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub(crate) struct TlsConnInfo {
    /// Remote peer address. `auth_middleware` falls back to it for per-IP
    /// rate limiting when no `ConnectInfo<SocketAddr>` is present.
    pub addr: SocketAddr,
    /// Identity derived from the client certificate, if any. When `Some`,
    /// `auth_middleware` accepts the request without checking bearer
    /// credentials.
    pub identity: Option<AuthIdentity>,
}

impl TlsConnInfo {
    /// Bundles the peer address with the optional certificate identity.
    #[must_use]
    pub(crate) const fn new(addr: SocketAddr, identity: Option<AuthIdentity>) -> Self {
        Self { addr, identity }
    }
}
775
/// Capacity used by [`SeenIdentitySet::new`].
const DEFAULT_SEEN_IDENTITY_CAP: usize = 4096;

/// Bounded, mutex-protected set of identity names that have already been
/// logged.
///
/// `AuthState::log_auth` uses it to log the first authentication of a name
/// at `info` and later ones at `debug`. When full, the oldest-inserted name
/// is evicted (FIFO), so an evicted name is reported as "first" again.
pub(crate) struct SeenIdentitySet {
    inner: Mutex<SeenInner>,
}

/// State guarded by the [`SeenIdentitySet`] mutex.
struct SeenInner {
    /// Membership lookup.
    set: HashSet<String>,
    /// Insertion order; the front is the next eviction victim.
    order: std::collections::VecDeque<String>,
    /// Maximum number of retained names (always at least 1).
    cap: usize,
}

impl SeenIdentitySet {
    /// Set with the default capacity.
    #[must_use]
    pub(crate) fn new() -> Self {
        Self::with_cap(DEFAULT_SEEN_IDENTITY_CAP)
    }

    /// Set holding at most `cap` names; `0` is raised to `1`. Preallocation
    /// is capped at 64 entries regardless of `cap`.
    #[must_use]
    pub(crate) fn with_cap(cap: usize) -> Self {
        let cap = std::cmp::max(cap, 1);
        let initial = std::cmp::min(cap, 64);
        let inner = SeenInner {
            set: HashSet::with_capacity(initial),
            order: std::collections::VecDeque::with_capacity(initial),
            cap,
        };
        Self {
            inner: Mutex::new(inner),
        }
    }

    /// Records `name` and reports whether it was not already present.
    pub(crate) fn insert_is_first(&self, name: &str) -> bool {
        // A poisoned lock only means another thread panicked mid-update; the
        // data only drives log levels, so keep using it.
        let mut state = self
            .inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        if state.set.contains(name) {
            return false;
        }

        let evicted = if state.set.len() >= state.cap {
            state.order.pop_front()
        } else {
            None
        };
        if let Some(oldest) = evicted {
            state.set.remove(&oldest);
        }

        state.set.insert(name.to_owned());
        state.order.push_back(name.to_owned());
        true
    }

    /// Number of names currently retained (test-only).
    #[cfg(test)]
    pub(crate) fn len(&self) -> usize {
        let state = self
            .inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        state.set.len()
    }
}

impl Default for SeenIdentitySet {
    fn default() -> Self {
        Self::new()
    }
}
888
889#[allow(
894 missing_debug_implementations,
895 reason = "contains governor RateLimiter and JwksCache without Debug impls"
896)]
897#[non_exhaustive]
898pub(crate) struct AuthState {
899 pub api_keys: ArcSwap<Vec<ApiKeyEntry>>,
901 pub rate_limiter: Option<Arc<KeyedLimiter>>,
903 pub pre_auth_limiter: Option<Arc<KeyedLimiter>>,
906 #[cfg(feature = "oauth")]
907 pub jwks_cache: Option<Arc<crate::oauth::JwksCache>>,
909 pub seen_identities: SeenIdentitySet,
914 pub counters: AuthCounters,
916}
917
918impl AuthState {
919 pub(crate) fn reload_keys(&self, keys: Vec<ApiKeyEntry>) {
925 let count = keys.len();
926 self.api_keys.store(Arc::new(keys));
927 tracing::info!(keys = count, "API keys reloaded");
928 }
929
930 #[must_use]
932 pub(crate) fn counters_snapshot(&self) -> AuthCountersSnapshot {
933 self.counters.snapshot()
934 }
935
936 #[must_use]
938 pub(crate) fn api_key_summaries(&self) -> Vec<ApiKeySummary> {
939 self.api_keys
940 .load()
941 .iter()
942 .map(|k| ApiKeySummary {
943 name: k.name.clone(),
944 role: k.role.clone(),
945 expires_at: k.expires_at,
946 })
947 .collect()
948 }
949
950 fn log_auth(&self, id: &AuthIdentity, method: &str) {
958 self.counters.record_success(id.method);
959 let first = self.seen_identities.insert_is_first(&id.name);
960 if first {
961 tracing::info!(name = %id.name, role = %id.role, "{method} authenticated");
962 } else {
963 tracing::debug!(name = %id.name, role = %id.role, "{method} authenticated");
964 }
965 }
966}
967
968const DEFAULT_AUTH_RATE: NonZeroU32 = NonZeroU32::new(30).unwrap();
971
972#[must_use]
974pub(crate) fn build_rate_limiter(config: &RateLimitConfig) -> Arc<KeyedLimiter> {
975 let quota = governor::Quota::per_minute(
976 NonZeroU32::new(config.max_attempts_per_minute).unwrap_or(DEFAULT_AUTH_RATE),
977 );
978 Arc::new(BoundedKeyedLimiter::new(
979 quota,
980 config.max_tracked_keys,
981 config.idle_eviction,
982 ))
983}
984
985#[must_use]
992pub(crate) fn build_pre_auth_limiter(config: &RateLimitConfig) -> Arc<KeyedLimiter> {
993 let resolved = config.pre_auth_max_per_minute.unwrap_or_else(|| {
994 config
995 .max_attempts_per_minute
996 .saturating_mul(PRE_AUTH_DEFAULT_MULTIPLIER)
997 });
998 let quota =
999 governor::Quota::per_minute(NonZeroU32::new(resolved).unwrap_or(DEFAULT_PRE_AUTH_RATE));
1000 Arc::new(BoundedKeyedLimiter::new(
1001 quota,
1002 config.max_tracked_keys,
1003 config.idle_eviction,
1004 ))
1005}
1006
1007const PRE_AUTH_DEFAULT_MULTIPLIER: u32 = 10;
1010
1011const DEFAULT_PRE_AUTH_RATE: NonZeroU32 = NonZeroU32::new(300).unwrap();
1015
1016#[must_use]
1021pub fn extract_mtls_identity(cert_der: &[u8], default_role: &str) -> Option<AuthIdentity> {
1022 let (_, cert) = X509Certificate::from_der(cert_der).ok()?;
1023
1024 let cn = cert
1026 .subject()
1027 .iter_common_name()
1028 .next()
1029 .and_then(|attr| attr.as_str().ok())
1030 .map(String::from);
1031
1032 let name = cn.or_else(|| {
1034 cert.subject_alternative_name()
1035 .ok()
1036 .flatten()
1037 .and_then(|san| {
1038 #[allow(
1039 clippy::wildcard_enum_match_arm,
1040 reason = "x509-parser GeneralName is a large external enum; only DNSName is meaningful here"
1041 )]
1042 san.value.general_names.iter().find_map(|gn| match gn {
1043 GeneralName::DNSName(dns) => Some((*dns).to_owned()),
1044 _ => None,
1045 })
1046 })
1047 })?;
1048
1049 if !name
1051 .chars()
1052 .all(|c| c.is_alphanumeric() || matches!(c, '-' | '.' | '_' | '@'))
1053 {
1054 tracing::warn!(cn = %name, "mTLS identity rejected: invalid characters in CN/SAN");
1055 return None;
1056 }
1057
1058 Some(AuthIdentity {
1059 name,
1060 role: default_role.to_owned(),
1061 method: AuthMethod::MtlsCertificate,
1062 raw_token: None,
1063 sub: None,
1064 })
1065}
1066
/// Extracts the token from an `Authorization` header value that uses the
/// `Bearer` scheme.
///
/// The scheme is matched ASCII-case-insensitively and must be followed by at
/// least one space. Extra spaces before the token are skipped; trailing
/// whitespace is kept. Returns `None` for other schemes, a missing space, or
/// an empty token.
fn extract_bearer(value: &str) -> Option<&str> {
    let (scheme, rest) = value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    Some(rest.trim_start_matches(' ')).filter(|token| !token.is_empty())
}
1090
/// Verifies a bearer token against the configured API keys.
///
/// Every key slot performs exactly one Argon2 verification, whether or not
/// an earlier slot matched, so the number of verifications always equals
/// `keys.len()`. A slot is verified against `DUMMY_PHC_HASH` instead of its
/// own hash when the key is expired, its hash is not a parsable PHC string,
/// or a match was already found.
///
/// Returns the identity (name and role) of the first non-expired key whose
/// hash verifies the token, or `None` if no key matches. `raw_token` and
/// `sub` are always `None`.
///
/// The function is CPU-heavy. `authenticate_bearer_identity` runs it via
/// `tokio::task::spawn_blocking`.
///
/// # Panics
///
/// Only if `DUMMY_PHC_HASH` fails to parse as a PHC string, which the fixed
/// value it is built from rules out.
#[must_use]
pub fn verify_bearer_token(token: &str, keys: &[ApiKeyEntry]) -> Option<AuthIdentity> {
    use subtle::ConstantTimeEq as _;

    let now = chrono::Utc::now();
    #[allow(
        clippy::expect_used,
        reason = "DUMMY_PHC_HASH is a static LazyLock built from a fixed Argon2id PHC string by construction; PasswordHash::new on it is infallible. See DUMMY_PHC_HASH definition."
    )]
    let dummy_hash = PasswordHash::new(&DUMMY_PHC_HASH)
        .expect("DUMMY_PHC_HASH is a valid Argon2id PHC string by construction");

    // `usize::MAX` is a placeholder; it is only read when `any_match == 1`,
    // by which point it has been overwritten with a real index.
    let mut matched_index: usize = usize::MAX;
    // 0/1 flag: has any slot matched so far?
    let mut any_match: u8 = 0;

    for (idx, key) in keys.iter().enumerate() {
        // A key whose expiry is strictly in the past cannot authenticate.
        let expired = key.expires_at.is_some_and(|exp| exp.as_datetime() < &now);

        let real_hash = PasswordHash::new(&key.hash);
        // Use the real hash only for a usable slot while nothing has matched
        // yet; otherwise spend one verification on the dummy hash.
        let verify_against = match (&real_hash, expired, any_match) {
            (Ok(h), false, 0) => h,
            _ => &dummy_hash,
        };

        let slot_ok = u8::from(
            Argon2::default()
                .verify_password(token.as_bytes(), verify_against)
                .is_ok(),
        );

        // A slot counts only if it is unexpired and its hash parsed. This
        // masks out dummy-hash verifications of expired/malformed slots.
        let real_match = slot_ok & u8::from(!expired) & u8::from(real_hash.is_ok());
        // Keep only the first match. Later slots (checked against the dummy
        // hash) are masked by `1 - any_match`.
        let first_real_match = real_match & (1 - any_match);
        if first_real_match.ct_eq(&1).into() {
            matched_index = idx;
        }
        any_match |= real_match;
    }

    if any_match == 0 {
        return None;
    }
    // `get` rather than indexing, so the sentinel can never cause a panic.
    let key = keys.get(matched_index)?;
    Some(AuthIdentity {
        name: key.name.clone(),
        role: key.role.clone(),
        method: AuthMethod::BearerToken,
        raw_token: None,
        sub: None,
    })
}
1169
1170static DUMMY_PHC_HASH: LazyLock<String> = LazyLock::new(|| {
1183 #[allow(
1185 clippy::expect_used,
1186 reason = "fixed 22-char base64 ('AAAA...') decodes to a valid 16-byte salt; SaltString::from_b64 is infallible on this literal"
1187 )]
1188 let salt = SaltString::from_b64("AAAAAAAAAAAAAAAAAAAAAA")
1189 .expect("fixed 16-byte base64 salt is well-formed");
1190 #[allow(
1191 clippy::expect_used,
1192 reason = "Argon2::default() with a fixed plaintext and a well-formed salt is infallible; only fails on bad params/salt"
1193 )]
1194 Argon2::default()
1195 .hash_password(b"rmcp-server-kit-dummy", &salt)
1196 .expect("Argon2 default params hash a fixed plaintext")
1197 .to_string()
1198});
1199
1200pub fn generate_api_key() -> Result<(String, String), McpxError> {
1210 let mut token_bytes = [0u8; 32];
1211 rand::fill(&mut token_bytes);
1212 let token = URL_SAFE_NO_PAD.encode(token_bytes);
1213
1214 let mut salt_bytes = [0u8; 16];
1216 rand::fill(&mut salt_bytes);
1217 let salt = SaltString::encode_b64(&salt_bytes)
1218 .map_err(|e| McpxError::Auth(format!("salt encoding failed: {e}")))?;
1219 let hash = Argon2::default()
1220 .hash_password(token.as_bytes(), &salt)
1221 .map_err(|e| McpxError::Auth(format!("argon2id hashing failed: {e}")))?
1222 .to_string();
1223
1224 Ok((token, hash))
1225}
1226
1227fn build_www_authenticate_value(
1228 advertise_resource_metadata: bool,
1229 failure: AuthFailureClass,
1230) -> String {
1231 let (error, error_description) = failure.bearer_error();
1232 if advertise_resource_metadata {
1233 return format!(
1234 "Bearer resource_metadata=\"/.well-known/oauth-protected-resource\", error=\"{error}\", error_description=\"{error_description}\""
1235 );
1236 }
1237 format!("Bearer error=\"{error}\", error_description=\"{error_description}\"")
1238}
1239
1240fn auth_method_label(method: AuthMethod) -> &'static str {
1241 match method {
1242 AuthMethod::MtlsCertificate => "mTLS",
1243 AuthMethod::BearerToken => "bearer token",
1244 AuthMethod::OAuthJwt => "OAuth JWT",
1245 }
1246}
1247
1248#[cfg_attr(not(feature = "oauth"), allow(unused_variables))]
1249fn unauthorized_response(state: &AuthState, failure_class: AuthFailureClass) -> Response {
1250 #[cfg(feature = "oauth")]
1251 let advertise_resource_metadata = state.jwks_cache.is_some();
1252 #[cfg(not(feature = "oauth"))]
1253 let advertise_resource_metadata = false;
1254
1255 let challenge = build_www_authenticate_value(advertise_resource_metadata, failure_class);
1256 (
1257 axum::http::StatusCode::UNAUTHORIZED,
1258 [(header::WWW_AUTHENTICATE, challenge)],
1259 failure_class.response_body(),
1260 )
1261 .into_response()
1262}
1263
1264async fn authenticate_bearer_identity(
1265 state: &AuthState,
1266 token: &str,
1267) -> Result<AuthIdentity, AuthFailureClass> {
1268 let mut failure_class = AuthFailureClass::MissingCredential;
1269
1270 #[cfg(feature = "oauth")]
1271 if let Some(ref cache) = state.jwks_cache
1272 && crate::oauth::looks_like_jwt(token)
1273 {
1274 match cache.validate_token_with_reason(token).await {
1275 Ok(mut id) => {
1276 id.raw_token = Some(SecretString::from(token.to_owned()));
1277 return Ok(id);
1278 }
1279 Err(crate::oauth::JwtValidationFailure::Expired) => {
1280 failure_class = AuthFailureClass::ExpiredCredential;
1281 }
1282 Err(crate::oauth::JwtValidationFailure::Invalid) => {
1283 failure_class = AuthFailureClass::InvalidCredential;
1284 }
1285 }
1286 }
1287
1288 let token = token.to_owned();
1289 let keys = state.api_keys.load_full(); let identity = tokio::task::spawn_blocking(move || verify_bearer_token(&token, &keys))
1293 .await
1294 .ok()
1295 .flatten();
1296
1297 if let Some(id) = identity {
1298 return Ok(id);
1299 }
1300
1301 if failure_class == AuthFailureClass::MissingCredential {
1302 failure_class = AuthFailureClass::InvalidCredential;
1303 }
1304
1305 Err(failure_class)
1306}
1307
1308fn pre_auth_gate(state: &AuthState, peer_addr: Option<SocketAddr>) -> Option<Response> {
1319 let limiter = state.pre_auth_limiter.as_ref()?;
1320 let addr = peer_addr?;
1321 if limiter.check_key(&addr.ip()).is_ok() {
1322 return None;
1323 }
1324 state.counters.record_failure(AuthFailureClass::PreAuthGate);
1325 tracing::warn!(
1326 ip = %addr.ip(),
1327 "auth rate limited by pre-auth gate (request rejected before credential verification)"
1328 );
1329 Some(
1330 McpxError::RateLimited("too many unauthenticated requests from this source".into())
1331 .into_response(),
1332 )
1333}
1334
1335pub(crate) async fn auth_middleware(
1344 state: Arc<AuthState>,
1345 req: Request<Body>,
1346 next: Next,
1347) -> Response {
1348 let tls_info = req.extensions().get::<ConnectInfo<TlsConnInfo>>().cloned();
1353 let peer_addr = req
1354 .extensions()
1355 .get::<ConnectInfo<SocketAddr>>()
1356 .map(|ci| ci.0)
1357 .or_else(|| tls_info.as_ref().map(|ci| ci.0.addr));
1358
1359 if let Some(id) = tls_info.and_then(|ci| ci.0.identity) {
1366 state.log_auth(&id, "mTLS");
1367 let mut req = req;
1368 req.extensions_mut().insert(id);
1369 return next.run(req).await;
1370 }
1371
1372 if let Some(blocked) = pre_auth_gate(&state, peer_addr) {
1376 return blocked;
1377 }
1378
1379 let failure_class = if let Some(value) = req.headers().get(header::AUTHORIZATION) {
1380 match value.to_str().ok().and_then(extract_bearer) {
1381 Some(token) => match authenticate_bearer_identity(&state, token).await {
1382 Ok(id) => {
1383 state.log_auth(&id, auth_method_label(id.method));
1384 let mut req = req;
1385 req.extensions_mut().insert(id);
1386 return next.run(req).await;
1387 }
1388 Err(class) => class,
1389 },
1390 None => AuthFailureClass::InvalidCredential,
1391 }
1392 } else {
1393 AuthFailureClass::MissingCredential
1394 };
1395
1396 tracing::warn!(failure_class = %failure_class.as_str(), "auth failed");
1397
1398 if let (Some(limiter), Some(addr)) = (&state.rate_limiter, peer_addr)
1401 && limiter.check_key(&addr.ip()).is_err()
1402 {
1403 state.counters.record_failure(AuthFailureClass::RateLimited);
1404 tracing::warn!(ip = %addr.ip(), "auth rate limited after repeated failures");
1405 return McpxError::RateLimited("too many failed authentication attempts".into())
1406 .into_response();
1407 }
1408
1409 state.counters.record_failure(failure_class);
1410 unauthorized_response(&state, failure_class)
1411}
1412
1413#[cfg(test)]
1414mod tests {
1415 use super::*;
1416
1417 #[test]
1418 fn generate_and_verify_api_key() {
1419 let (token, hash) = generate_api_key().unwrap();
1420
1421 assert_eq!(token.len(), 43);
1423
1424 assert!(hash.starts_with("$argon2id$"));
1426
1427 let keys = vec![ApiKeyEntry {
1429 name: "test".into(),
1430 hash,
1431 role: "viewer".into(),
1432 expires_at: None,
1433 }];
1434 let id = verify_bearer_token(&token, &keys);
1435 assert!(id.is_some());
1436 let id = id.unwrap();
1437 assert_eq!(id.name, "test");
1438 assert_eq!(id.role, "viewer");
1439 assert_eq!(id.method, AuthMethod::BearerToken);
1440 }
1441
1442 #[test]
1443 fn wrong_token_rejected() {
1444 let (_token, hash) = generate_api_key().unwrap();
1445 let keys = vec![ApiKeyEntry {
1446 name: "test".into(),
1447 hash,
1448 role: "viewer".into(),
1449 expires_at: None,
1450 }];
1451 assert!(verify_bearer_token("wrong-token", &keys).is_none());
1452 }
1453
1454 #[test]
1455 fn expired_key_rejected() {
1456 let (token, hash) = generate_api_key().unwrap();
1457 let keys = vec![ApiKeyEntry {
1458 name: "test".into(),
1459 hash,
1460 role: "viewer".into(),
1461 expires_at: Some(RfcTimestamp::parse("2020-01-01T00:00:00Z").unwrap()),
1462 }];
1463 assert!(verify_bearer_token(&token, &keys).is_none());
1464 }
1465
1466 #[test]
1467 fn match_in_last_slot_still_authenticates() {
1468 let (token, hash) = generate_api_key().unwrap();
1469 let (_other_token, other_hash) = generate_api_key().unwrap();
1470 let keys = vec![
1471 ApiKeyEntry {
1472 name: "first".into(),
1473 hash: other_hash.clone(),
1474 role: "viewer".into(),
1475 expires_at: None,
1476 },
1477 ApiKeyEntry {
1478 name: "second".into(),
1479 hash: other_hash,
1480 role: "viewer".into(),
1481 expires_at: None,
1482 },
1483 ApiKeyEntry {
1484 name: "match".into(),
1485 hash,
1486 role: "ops".into(),
1487 expires_at: None,
1488 },
1489 ];
1490 let id = verify_bearer_token(&token, &keys).expect("last-slot match must authenticate");
1491 assert_eq!(id.name, "match");
1492 assert_eq!(id.role, "ops");
1493 }
1494
1495 #[test]
1496 fn expired_slot_before_valid_match_does_not_short_circuit() {
1497 let (token, hash) = generate_api_key().unwrap();
1498 let (_, other_hash) = generate_api_key().unwrap();
1499 let keys = vec![
1500 ApiKeyEntry {
1501 name: "expired".into(),
1502 hash: other_hash,
1503 role: "viewer".into(),
1504 expires_at: Some(RfcTimestamp::parse("2020-01-01T00:00:00Z").unwrap()),
1505 },
1506 ApiKeyEntry {
1507 name: "valid".into(),
1508 hash,
1509 role: "ops".into(),
1510 expires_at: None,
1511 },
1512 ];
1513 let id = verify_bearer_token(&token, &keys)
1514 .expect("valid slot following an expired slot must authenticate");
1515 assert_eq!(id.name, "valid");
1516 }
1517
1518 #[test]
1519 fn malformed_hash_slot_does_not_short_circuit() {
1520 let (token, hash) = generate_api_key().unwrap();
1521 let keys = vec![
1522 ApiKeyEntry {
1523 name: "broken".into(),
1524 hash: "this-is-not-a-phc-string".into(),
1525 role: "viewer".into(),
1526 expires_at: None,
1527 },
1528 ApiKeyEntry {
1529 name: "valid".into(),
1530 hash,
1531 role: "ops".into(),
1532 expires_at: None,
1533 },
1534 ];
1535 let id = verify_bearer_token(&token, &keys)
1536 .expect("valid slot following a malformed-hash slot must authenticate");
1537 assert_eq!(id.name, "valid");
1538 }
1539
1540 #[test]
1551 fn rfc_timestamp_parse_rejects_malformed() {
1552 for bad in [
1553 "not-a-date",
1554 "",
1555 "2025-13-01T00:00:00Z", "2025-01-32T00:00:00Z", "2025-01-01T00:00:00", "01/01/2025", "2025-01-01T25:00:00Z", ] {
1561 assert!(
1562 RfcTimestamp::parse(bad).is_err(),
1563 "RfcTimestamp::parse must reject {bad:?}"
1564 );
1565 }
1566 }
1567
1568 #[test]
1569 fn rfc_timestamp_parse_accepts_valid() {
1570 for good in [
1571 "2025-01-01T00:00:00Z",
1572 "2025-01-01T00:00:00+00:00",
1573 "2025-12-31T23:59:59-08:00",
1574 "2099-01-01T00:00:00.123456789Z",
1575 ] {
1576 assert!(
1577 RfcTimestamp::parse(good).is_ok(),
1578 "RfcTimestamp::parse must accept {good:?}"
1579 );
1580 }
1581 }
1582
1583 #[test]
1584 fn api_key_entry_deserialize_rejects_malformed_expires_at() {
1585 let toml = r#"
1590 name = "bad-key"
1591 hash = "$argon2id$v=19$m=19456,t=2,p=1$c2FsdA$h4sh"
1592 role = "viewer"
1593 expires_at = "not-a-date"
1594 "#;
1595 let result: Result<ApiKeyEntry, _> = toml::from_str(toml);
1596 assert!(
1597 result.is_err(),
1598 "deserialization must reject malformed expires_at"
1599 );
1600 }
1601
1602 #[test]
1603 fn api_key_entry_deserialize_accepts_valid_expires_at() {
1604 let toml = r#"
1605 name = "good-key"
1606 hash = "$argon2id$v=19$m=19456,t=2,p=1$c2FsdA$h4sh"
1607 role = "viewer"
1608 expires_at = "2099-01-01T00:00:00Z"
1609 "#;
1610 let entry: ApiKeyEntry = toml::from_str(toml).expect("valid RFC 3339 must deserialize");
1611 assert!(entry.expires_at.is_some());
1612 }
1613
1614 #[test]
1615 fn api_key_entry_deserialize_accepts_missing_expires_at() {
1616 let toml = r#"
1619 name = "eternal-key"
1620 hash = "$argon2id$v=19$m=19456,t=2,p=1$c2FsdA$h4sh"
1621 role = "viewer"
1622 "#;
1623 let entry: ApiKeyEntry = toml::from_str(toml).expect("missing expires_at must deserialize");
1624 assert!(entry.expires_at.is_none());
1625 }
1626
1627 #[test]
1628 fn try_with_expiry_rejects_malformed() {
1629 let entry = ApiKeyEntry::new("k", "hash", "viewer");
1630 assert!(entry.try_with_expiry("not-a-date").is_err());
1631 }
1632
1633 #[test]
1634 fn try_with_expiry_accepts_valid() {
1635 let entry = ApiKeyEntry::new("k", "hash", "viewer")
1636 .try_with_expiry("2099-01-01T00:00:00Z")
1637 .expect("valid RFC 3339 must be accepted");
1638 assert!(entry.expires_at.is_some());
1639 }
1640
1641 #[test]
1642 fn api_key_summary_serializes_expires_at_as_rfc3339() {
1643 let summary = ApiKeySummary {
1648 name: "k".into(),
1649 role: "viewer".into(),
1650 expires_at: Some(RfcTimestamp::parse("2030-01-01T00:00:00Z").unwrap()),
1651 };
1652 let json = serde_json::to_string(&summary).unwrap();
1653 assert!(
1654 json.contains(r#""expires_at":"2030-01-01T00:00:00+00:00""#),
1655 "wire format regressed: {json}"
1656 );
1657 }
1658
1659 #[test]
1660 fn future_expiry_accepted() {
1661 let (token, hash) = generate_api_key().unwrap();
1662 let keys = vec![ApiKeyEntry {
1663 name: "test".into(),
1664 hash,
1665 role: "viewer".into(),
1666 expires_at: Some(RfcTimestamp::parse("2099-01-01T00:00:00Z").unwrap()),
1667 }];
1668 assert!(verify_bearer_token(&token, &keys).is_some());
1669 }
1670
1671 #[test]
1672 fn multiple_keys_first_match_wins() {
1673 let (token, hash) = generate_api_key().unwrap();
1674 let keys = vec![
1675 ApiKeyEntry {
1676 name: "wrong".into(),
1677 hash: "$argon2id$v=19$m=19456,t=2,p=1$invalid$invalid".into(),
1678 role: "ops".into(),
1679 expires_at: None,
1680 },
1681 ApiKeyEntry {
1682 name: "correct".into(),
1683 hash,
1684 role: "deploy".into(),
1685 expires_at: None,
1686 },
1687 ];
1688 let id = verify_bearer_token(&token, &keys).unwrap();
1689 assert_eq!(id.name, "correct");
1690 assert_eq!(id.role, "deploy");
1691 }
1692
1693 #[test]
1694 fn rate_limiter_allows_within_quota() {
1695 let config = RateLimitConfig {
1696 max_attempts_per_minute: 5,
1697 pre_auth_max_per_minute: None,
1698 max_tracked_keys: default_max_tracked_keys(),
1699 idle_eviction: default_idle_eviction(),
1700 };
1701 let limiter = build_rate_limiter(&config);
1702 let ip: IpAddr = "10.0.0.1".parse().unwrap();
1703
1704 for _ in 0..5 {
1706 assert!(limiter.check_key(&ip).is_ok());
1707 }
1708 assert!(limiter.check_key(&ip).is_err());
1710 }
1711
1712 #[test]
1713 fn rate_limiter_separate_ips() {
1714 let config = RateLimitConfig {
1715 max_attempts_per_minute: 2,
1716 pre_auth_max_per_minute: None,
1717 max_tracked_keys: default_max_tracked_keys(),
1718 idle_eviction: default_idle_eviction(),
1719 };
1720 let limiter = build_rate_limiter(&config);
1721 let ip1: IpAddr = "10.0.0.1".parse().unwrap();
1722 let ip2: IpAddr = "10.0.0.2".parse().unwrap();
1723
1724 assert!(limiter.check_key(&ip1).is_ok());
1726 assert!(limiter.check_key(&ip1).is_ok());
1727 assert!(limiter.check_key(&ip1).is_err());
1728
1729 assert!(limiter.check_key(&ip2).is_ok());
1731 }
1732
1733 #[test]
1734 fn extract_mtls_identity_from_cn() {
1735 let mut params = rcgen::CertificateParams::new(vec!["test-client.local".into()]).unwrap();
1737 params.distinguished_name = rcgen::DistinguishedName::new();
1738 params
1739 .distinguished_name
1740 .push(rcgen::DnType::CommonName, "test-client");
1741 let cert = params
1742 .self_signed(&rcgen::KeyPair::generate().unwrap())
1743 .unwrap();
1744 let der = cert.der();
1745
1746 let id = extract_mtls_identity(der, "ops").unwrap();
1747 assert_eq!(id.name, "test-client");
1748 assert_eq!(id.role, "ops");
1749 assert_eq!(id.method, AuthMethod::MtlsCertificate);
1750 }
1751
1752 #[test]
1753 fn extract_mtls_identity_falls_back_to_san() {
1754 let mut params =
1756 rcgen::CertificateParams::new(vec!["san-only.example.com".into()]).unwrap();
1757 params.distinguished_name = rcgen::DistinguishedName::new();
1758 let cert = params
1760 .self_signed(&rcgen::KeyPair::generate().unwrap())
1761 .unwrap();
1762 let der = cert.der();
1763
1764 let id = extract_mtls_identity(der, "viewer").unwrap();
1765 assert_eq!(id.name, "san-only.example.com");
1766 assert_eq!(id.role, "viewer");
1767 }
1768
1769 #[test]
1770 fn extract_mtls_identity_invalid_der() {
1771 assert!(extract_mtls_identity(b"not-a-cert", "viewer").is_none());
1772 }
1773
1774 use axum::{
1777 body::Body,
1778 http::{Request, StatusCode},
1779 };
1780 use tower::ServiceExt as _;
1781
1782 fn auth_router(state: Arc<AuthState>) -> axum::Router {
1783 axum::Router::new()
1784 .route("/mcp", axum::routing::post(|| async { "ok" }))
1785 .layer(axum::middleware::from_fn(move |req, next| {
1786 let s = Arc::clone(&state);
1787 auth_middleware(s, req, next)
1788 }))
1789 }
1790
1791 fn test_auth_state(keys: Vec<ApiKeyEntry>) -> Arc<AuthState> {
1792 Arc::new(AuthState {
1793 api_keys: ArcSwap::new(Arc::new(keys)),
1794 rate_limiter: None,
1795 pre_auth_limiter: None,
1796 #[cfg(feature = "oauth")]
1797 jwks_cache: None,
1798 seen_identities: SeenIdentitySet::new(),
1799 counters: AuthCounters::default(),
1800 })
1801 }
1802
1803 #[tokio::test]
1804 async fn middleware_rejects_no_credentials() {
1805 let state = test_auth_state(vec![]);
1806 let app = auth_router(Arc::clone(&state));
1807 let req = Request::builder()
1808 .method(axum::http::Method::POST)
1809 .uri("/mcp")
1810 .body(Body::empty())
1811 .unwrap();
1812 let resp = app.oneshot(req).await.unwrap();
1813 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
1814 let challenge = resp
1815 .headers()
1816 .get(header::WWW_AUTHENTICATE)
1817 .unwrap()
1818 .to_str()
1819 .unwrap();
1820 assert!(challenge.contains("error=\"invalid_request\""));
1821
1822 let counters = state.counters_snapshot();
1823 assert_eq!(counters.failure_missing_credential, 1);
1824 }
1825
1826 #[tokio::test]
1827 async fn middleware_accepts_valid_bearer() {
1828 let (token, hash) = generate_api_key().unwrap();
1829 let keys = vec![ApiKeyEntry {
1830 name: "test-key".into(),
1831 hash,
1832 role: "ops".into(),
1833 expires_at: None,
1834 }];
1835 let state = test_auth_state(keys);
1836 let app = auth_router(Arc::clone(&state));
1837 let req = Request::builder()
1838 .method(axum::http::Method::POST)
1839 .uri("/mcp")
1840 .header("authorization", format!("Bearer {token}"))
1841 .body(Body::empty())
1842 .unwrap();
1843 let resp = app.oneshot(req).await.unwrap();
1844 assert_eq!(resp.status(), StatusCode::OK);
1845
1846 let counters = state.counters_snapshot();
1847 assert_eq!(counters.success_bearer, 1);
1848 }
1849
1850 #[tokio::test]
1851 async fn middleware_rejects_wrong_bearer() {
1852 let (_token, hash) = generate_api_key().unwrap();
1853 let keys = vec![ApiKeyEntry {
1854 name: "test-key".into(),
1855 hash,
1856 role: "ops".into(),
1857 expires_at: None,
1858 }];
1859 let state = test_auth_state(keys);
1860 let app = auth_router(Arc::clone(&state));
1861 let req = Request::builder()
1862 .method(axum::http::Method::POST)
1863 .uri("/mcp")
1864 .header("authorization", "Bearer wrong-token-here")
1865 .body(Body::empty())
1866 .unwrap();
1867 let resp = app.oneshot(req).await.unwrap();
1868 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
1869 let challenge = resp
1870 .headers()
1871 .get(header::WWW_AUTHENTICATE)
1872 .unwrap()
1873 .to_str()
1874 .unwrap();
1875 assert!(challenge.contains("error=\"invalid_token\""));
1876
1877 let counters = state.counters_snapshot();
1878 assert_eq!(counters.failure_invalid_credential, 1);
1879 }
1880
1881 #[tokio::test]
1882 async fn middleware_rate_limits() {
1883 let state = Arc::new(AuthState {
1884 api_keys: ArcSwap::new(Arc::new(vec![])),
1885 rate_limiter: Some(build_rate_limiter(&RateLimitConfig {
1886 max_attempts_per_minute: 1,
1887 pre_auth_max_per_minute: None,
1888 max_tracked_keys: default_max_tracked_keys(),
1889 idle_eviction: default_idle_eviction(),
1890 })),
1891 pre_auth_limiter: None,
1892 #[cfg(feature = "oauth")]
1893 jwks_cache: None,
1894 seen_identities: SeenIdentitySet::new(),
1895 counters: AuthCounters::default(),
1896 });
1897 let app = auth_router(state);
1898
1899 let req = Request::builder()
1901 .method(axum::http::Method::POST)
1902 .uri("/mcp")
1903 .body(Body::empty())
1904 .unwrap();
1905 let resp = app.clone().oneshot(req).await.unwrap();
1906 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
1907
1908 }
1913
1914 #[test]
1920 fn rate_limit_semantics_failed_only() {
1921 let config = RateLimitConfig {
1922 max_attempts_per_minute: 3,
1923 pre_auth_max_per_minute: None,
1924 max_tracked_keys: default_max_tracked_keys(),
1925 idle_eviction: default_idle_eviction(),
1926 };
1927 let limiter = build_rate_limiter(&config);
1928 let ip: IpAddr = "192.168.1.100".parse().unwrap();
1929
1930 assert!(
1932 limiter.check_key(&ip).is_ok(),
1933 "failure 1 should be allowed"
1934 );
1935 assert!(
1936 limiter.check_key(&ip).is_ok(),
1937 "failure 2 should be allowed"
1938 );
1939 assert!(
1940 limiter.check_key(&ip).is_ok(),
1941 "failure 3 should be allowed"
1942 );
1943 assert!(
1944 limiter.check_key(&ip).is_err(),
1945 "failure 4 should be blocked"
1946 );
1947
1948 }
1957
1958 #[test]
1963 fn pre_auth_default_multiplier_is_10x() {
1964 let config = RateLimitConfig {
1965 max_attempts_per_minute: 5,
1966 pre_auth_max_per_minute: None,
1967 max_tracked_keys: default_max_tracked_keys(),
1968 idle_eviction: default_idle_eviction(),
1969 };
1970 let limiter = build_pre_auth_limiter(&config);
1971 let ip: IpAddr = "10.0.0.1".parse().unwrap();
1972
1973 for i in 0..50 {
1975 assert!(
1976 limiter.check_key(&ip).is_ok(),
1977 "pre-auth attempt {i} (of expected 50) should be allowed under default 10x multiplier"
1978 );
1979 }
1980 assert!(
1982 limiter.check_key(&ip).is_err(),
1983 "pre-auth attempt 51 should be blocked (quota is 50, not unbounded)"
1984 );
1985 }
1986
1987 #[test]
1990 fn pre_auth_explicit_override_wins() {
1991 let config = RateLimitConfig {
1992 max_attempts_per_minute: 100, pre_auth_max_per_minute: Some(2), max_tracked_keys: default_max_tracked_keys(),
1995 idle_eviction: default_idle_eviction(),
1996 };
1997 let limiter = build_pre_auth_limiter(&config);
1998 let ip: IpAddr = "10.0.0.2".parse().unwrap();
1999
2000 assert!(limiter.check_key(&ip).is_ok(), "attempt 1 allowed");
2001 assert!(limiter.check_key(&ip).is_ok(), "attempt 2 allowed");
2002 assert!(
2003 limiter.check_key(&ip).is_err(),
2004 "attempt 3 must be blocked (explicit override of 2 wins over 10x default of 1000)"
2005 );
2006 }
2007
2008 #[tokio::test]
2014 async fn pre_auth_gate_blocks_before_argon2_verification() {
2015 let (_token, hash) = generate_api_key().unwrap();
2016 let keys = vec![ApiKeyEntry {
2017 name: "test-key".into(),
2018 hash,
2019 role: "ops".into(),
2020 expires_at: None,
2021 }];
2022 let config = RateLimitConfig {
2023 max_attempts_per_minute: 100,
2024 pre_auth_max_per_minute: Some(1),
2025 max_tracked_keys: default_max_tracked_keys(),
2026 idle_eviction: default_idle_eviction(),
2027 };
2028 let state = Arc::new(AuthState {
2029 api_keys: ArcSwap::new(Arc::new(keys)),
2030 rate_limiter: None,
2031 pre_auth_limiter: Some(build_pre_auth_limiter(&config)),
2032 #[cfg(feature = "oauth")]
2033 jwks_cache: None,
2034 seen_identities: SeenIdentitySet::new(),
2035 counters: AuthCounters::default(),
2036 });
2037 let app = auth_router(Arc::clone(&state));
2038 let peer: SocketAddr = "10.0.0.10:54321".parse().unwrap();
2039
2040 let mut req1 = Request::builder()
2043 .method(axum::http::Method::POST)
2044 .uri("/mcp")
2045 .header("authorization", "Bearer obviously-not-a-real-token")
2046 .body(Body::empty())
2047 .unwrap();
2048 req1.extensions_mut().insert(ConnectInfo(peer));
2049 let resp1 = app.clone().oneshot(req1).await.unwrap();
2050 assert_eq!(
2051 resp1.status(),
2052 StatusCode::UNAUTHORIZED,
2053 "first attempt: gate has quota, falls through to bearer auth which fails with 401"
2054 );
2055
2056 let mut req2 = Request::builder()
2059 .method(axum::http::Method::POST)
2060 .uri("/mcp")
2061 .header("authorization", "Bearer also-not-a-real-token")
2062 .body(Body::empty())
2063 .unwrap();
2064 req2.extensions_mut().insert(ConnectInfo(peer));
2065 let resp2 = app.oneshot(req2).await.unwrap();
2066 assert_eq!(
2067 resp2.status(),
2068 StatusCode::TOO_MANY_REQUESTS,
2069 "second attempt from same IP: pre-auth gate must reject with 429"
2070 );
2071
2072 let counters = state.counters_snapshot();
2073 assert_eq!(
2074 counters.failure_pre_auth_gate, 1,
2075 "exactly one request must have been rejected by the pre-auth gate"
2076 );
2077 assert_eq!(
2081 counters.failure_invalid_credential, 1,
2082 "bearer verification must run exactly once (only the un-gated first request)"
2083 );
2084 }
2085
2086 #[tokio::test]
2093 async fn pre_auth_gate_does_not_throttle_mtls() {
2094 let config = RateLimitConfig {
2095 max_attempts_per_minute: 100,
2096 pre_auth_max_per_minute: Some(1), max_tracked_keys: default_max_tracked_keys(),
2098 idle_eviction: default_idle_eviction(),
2099 };
2100 let state = Arc::new(AuthState {
2101 api_keys: ArcSwap::new(Arc::new(vec![])),
2102 rate_limiter: None,
2103 pre_auth_limiter: Some(build_pre_auth_limiter(&config)),
2104 #[cfg(feature = "oauth")]
2105 jwks_cache: None,
2106 seen_identities: SeenIdentitySet::new(),
2107 counters: AuthCounters::default(),
2108 });
2109 let app = auth_router(Arc::clone(&state));
2110 let peer: SocketAddr = "10.0.0.20:54321".parse().unwrap();
2111 let identity = AuthIdentity {
2112 name: "cn=test-client".into(),
2113 role: "viewer".into(),
2114 method: AuthMethod::MtlsCertificate,
2115 raw_token: None,
2116 sub: None,
2117 };
2118 let tls_info = TlsConnInfo::new(peer, Some(identity));
2119
2120 for i in 0..3 {
2121 let mut req = Request::builder()
2122 .method(axum::http::Method::POST)
2123 .uri("/mcp")
2124 .body(Body::empty())
2125 .unwrap();
2126 req.extensions_mut().insert(ConnectInfo(tls_info.clone()));
2127 let resp = app.clone().oneshot(req).await.unwrap();
2128 assert_eq!(
2129 resp.status(),
2130 StatusCode::OK,
2131 "mTLS request {i} must succeed: pre-auth gate must not apply to mTLS callers"
2132 );
2133 }
2134
2135 let counters = state.counters_snapshot();
2136 assert_eq!(
2137 counters.failure_pre_auth_gate, 0,
2138 "pre-auth gate counter must remain at zero: mTLS bypasses the gate"
2139 );
2140 assert_eq!(
2141 counters.success_mtls, 3,
2142 "all three mTLS requests must have been counted as successful"
2143 );
2144 }
2145
2146 #[test]
2151 fn extract_bearer_accepts_canonical_case() {
2152 assert_eq!(extract_bearer("Bearer abc123"), Some("abc123"));
2153 }
2154
2155 #[test]
2156 fn extract_bearer_is_case_insensitive_per_rfc7235() {
2157 for header in &[
2161 "bearer abc123",
2162 "BEARER abc123",
2163 "BeArEr abc123",
2164 "bEaReR abc123",
2165 ] {
2166 assert_eq!(
2167 extract_bearer(header),
2168 Some("abc123"),
2169 "header {header:?} must parse as a Bearer token (RFC 7235 §2.1)"
2170 );
2171 }
2172 }
2173
2174 #[test]
2175 fn extract_bearer_rejects_other_schemes() {
2176 assert_eq!(extract_bearer("Basic dXNlcjpwYXNz"), None);
2177 assert_eq!(extract_bearer("Digest username=\"x\""), None);
2178 assert_eq!(extract_bearer("Token abc123"), None);
2179 }
2180
2181 #[test]
2182 fn extract_bearer_rejects_malformed() {
2183 assert_eq!(extract_bearer(""), None);
2185 assert_eq!(extract_bearer("Bearer"), None);
2186 assert_eq!(extract_bearer("Bearer "), None);
2187 assert_eq!(extract_bearer("Bearer "), None);
2188 }
2189
2190 #[test]
2191 fn extract_bearer_tolerates_extra_separator_whitespace() {
2192 assert_eq!(extract_bearer("Bearer abc123"), Some("abc123"));
2194 assert_eq!(extract_bearer("Bearer abc123"), Some("abc123"));
2195 }
2196
2197 #[test]
2203 fn auth_identity_debug_redacts_raw_token() {
2204 let id = AuthIdentity {
2205 name: "alice".into(),
2206 role: "admin".into(),
2207 method: AuthMethod::OAuthJwt,
2208 raw_token: Some(SecretString::from("super-secret-jwt-payload-xyz")),
2209 sub: Some("keycloak-uuid-2f3c8b".into()),
2210 };
2211 let dbg = format!("{id:?}");
2212
2213 assert!(dbg.contains("alice"), "name should be visible: {dbg}");
2215 assert!(dbg.contains("admin"), "role should be visible: {dbg}");
2216 assert!(dbg.contains("OAuthJwt"), "method should be visible: {dbg}");
2217
2218 assert!(
2220 !dbg.contains("super-secret-jwt-payload-xyz"),
2221 "raw_token must be redacted in Debug output: {dbg}"
2222 );
2223 assert!(
2224 !dbg.contains("keycloak-uuid-2f3c8b"),
2225 "sub must be redacted in Debug output: {dbg}"
2226 );
2227 assert!(
2228 dbg.contains("<redacted>"),
2229 "redaction marker missing: {dbg}"
2230 );
2231 }
2232
2233 #[test]
2234 fn auth_identity_debug_marks_absent_secrets() {
2235 let id = AuthIdentity {
2238 name: "viewer-key".into(),
2239 role: "viewer".into(),
2240 method: AuthMethod::BearerToken,
2241 raw_token: None,
2242 sub: None,
2243 };
2244 let dbg = format!("{id:?}");
2245 assert!(
2246 dbg.contains("<none>"),
2247 "absent secrets should be marked: {dbg}"
2248 );
2249 assert!(
2250 !dbg.contains("<redacted>"),
2251 "no <redacted> marker when secrets are absent: {dbg}"
2252 );
2253 }
2254
2255 #[test]
2256 fn api_key_entry_debug_redacts_hash() {
2257 let entry = ApiKeyEntry {
2258 name: "viewer-key".into(),
2259 hash: "$argon2id$v=19$m=19456,t=2,p=1$c2FsdHNhbHQ$h4sh3dPa55w0rd".into(),
2261 role: "viewer".into(),
2262 expires_at: Some(RfcTimestamp::parse("2030-01-01T00:00:00Z").unwrap()),
2263 };
2264 let dbg = format!("{entry:?}");
2265
2266 assert!(dbg.contains("viewer-key"));
2268 assert!(dbg.contains("viewer"));
2269 assert!(dbg.contains("2030-01-01T00:00:00+00:00"));
2270
2271 assert!(
2273 !dbg.contains("$argon2id$"),
2274 "argon2 hash leaked into Debug output: {dbg}"
2275 );
2276 assert!(
2277 !dbg.contains("h4sh3dPa55w0rd"),
2278 "hash digest leaked into Debug output: {dbg}"
2279 );
2280 assert!(
2281 dbg.contains("<redacted>"),
2282 "redaction marker missing: {dbg}"
2283 );
2284 }
2285
2286 #[test]
2297 fn auth_failure_class_as_str_exact_strings() {
2298 assert_eq!(
2299 AuthFailureClass::MissingCredential.as_str(),
2300 "missing_credential"
2301 );
2302 assert_eq!(
2303 AuthFailureClass::InvalidCredential.as_str(),
2304 "invalid_credential"
2305 );
2306 assert_eq!(
2307 AuthFailureClass::ExpiredCredential.as_str(),
2308 "expired_credential"
2309 );
2310 assert_eq!(AuthFailureClass::RateLimited.as_str(), "rate_limited");
2311 assert_eq!(AuthFailureClass::PreAuthGate.as_str(), "pre_auth_gate");
2312 }
2313
2314 #[test]
2315 fn auth_failure_class_response_body_exact_strings() {
2316 assert_eq!(
2317 AuthFailureClass::MissingCredential.response_body(),
2318 "unauthorized: missing credential"
2319 );
2320 assert_eq!(
2321 AuthFailureClass::InvalidCredential.response_body(),
2322 "unauthorized: invalid credential"
2323 );
2324 assert_eq!(
2325 AuthFailureClass::ExpiredCredential.response_body(),
2326 "unauthorized: expired credential"
2327 );
2328 assert_eq!(
2329 AuthFailureClass::RateLimited.response_body(),
2330 "rate limited"
2331 );
2332 assert_eq!(
2333 AuthFailureClass::PreAuthGate.response_body(),
2334 "rate limited (pre-auth)"
2335 );
2336 }
2337
2338 #[test]
2339 fn auth_failure_class_bearer_error_exact_strings() {
2340 assert_eq!(
2341 AuthFailureClass::MissingCredential.bearer_error(),
2342 (
2343 "invalid_request",
2344 "missing bearer token or mTLS client certificate"
2345 )
2346 );
2347 assert_eq!(
2348 AuthFailureClass::InvalidCredential.bearer_error(),
2349 ("invalid_token", "token is invalid")
2350 );
2351 assert_eq!(
2352 AuthFailureClass::ExpiredCredential.bearer_error(),
2353 ("invalid_token", "token is expired")
2354 );
2355 assert_eq!(
2356 AuthFailureClass::RateLimited.bearer_error(),
2357 ("invalid_request", "too many failed authentication attempts")
2358 );
2359 assert_eq!(
2360 AuthFailureClass::PreAuthGate.bearer_error(),
2361 (
2362 "invalid_request",
2363 "too many unauthenticated requests from this source"
2364 )
2365 );
2366 }
2367
2368 #[test]
2377 fn auth_config_summary_bearer_true_when_keys_present() {
2378 let (_token, hash) = generate_api_key().unwrap();
2379 let cfg = AuthConfig::with_keys(vec![ApiKeyEntry::new("k", hash, "viewer")]);
2380 let s = cfg.summary();
2381 assert!(s.enabled, "summary.enabled must reflect AuthConfig.enabled");
2382 assert!(
2383 s.bearer,
2384 "summary.bearer must be true when api_keys is non-empty (kills `!` deletion at L615)"
2385 );
2386 assert!(!s.mtls, "summary.mtls must be false when mtls is None");
2387 assert!(!s.oauth, "summary.oauth must be false when oauth is None");
2388 assert_eq!(s.api_keys.len(), 1);
2389 assert_eq!(s.api_keys[0].name, "k");
2390 assert_eq!(s.api_keys[0].role, "viewer");
2391 }
2392
2393 #[test]
2394 fn auth_config_summary_bearer_false_when_no_keys() {
2395 let cfg = AuthConfig::with_keys(vec![]);
2396 let s = cfg.summary();
2397 assert!(
2398 !s.bearer,
2399 "summary.bearer must be false when api_keys is empty (kills `!` deletion at L615)"
2400 );
2401 assert!(s.api_keys.is_empty());
2402 }
2403
2404 #[test]
2405 fn seen_identity_set_first_then_repeat() {
2406 let set = SeenIdentitySet::new();
2407 assert!(set.insert_is_first("alice"), "first sighting is first");
2408 assert!(
2409 !set.insert_is_first("alice"),
2410 "second sighting is not first"
2411 );
2412 assert!(set.insert_is_first("bob"));
2413 assert_eq!(set.len(), 2);
2414 }
2415
2416 #[test]
2417 fn seen_identity_set_evicts_oldest_at_cap() {
2418 let set = SeenIdentitySet::with_cap(2);
2419 assert!(set.insert_is_first("a"));
2420 assert!(set.insert_is_first("b"));
2421 assert!(set.insert_is_first("c"));
2423 assert_eq!(set.len(), 2);
2424 assert!(set.insert_is_first("a"));
2428 assert_eq!(set.len(), 2);
2429 assert!(set.insert_is_first("b"));
2431 for i in 0..32 {
2433 set.insert_is_first(&format!("churn-{i}"));
2434 assert!(set.len() <= 2, "cap invariant must hold");
2435 }
2436 }
2437
2438 #[test]
2439 fn seen_identity_set_cap_zero_is_raised_to_one() {
2440 let set = SeenIdentitySet::with_cap(0);
2441 assert!(set.insert_is_first("only"));
2442 assert_eq!(set.len(), 1);
2443 assert!(set.insert_is_first("next"));
2445 assert_eq!(set.len(), 1);
2446 }
2447
2448 #[test]
2449 fn seen_identity_set_fifo_does_not_refresh_on_repeat_hit() {
2450 let set = SeenIdentitySet::with_cap(2);
2453 assert!(set.insert_is_first("a")); assert!(set.insert_is_first("b")); assert!(!set.insert_is_first("a"));
2459 assert!(set.insert_is_first("c"));
2462 assert!(set.insert_is_first("a"));
2464 let set = SeenIdentitySet::with_cap(2);
2470 assert!(set.insert_is_first("x")); assert!(set.insert_is_first("y")); assert!(!set.insert_is_first("x")); assert!(set.insert_is_first("z")); assert!(
2475 !set.insert_is_first("y"),
2476 "y must still be present (FIFO did not evict it)"
2477 );
2478 assert!(
2479 set.insert_is_first("x"),
2480 "x must have been evicted by FIFO (would NOT have been evicted under LRU)"
2481 );
2482 }
2483}