1use std::{
10 collections::HashSet,
11 net::{IpAddr, SocketAddr},
12 num::NonZeroU32,
13 path::PathBuf,
14 sync::{
15 Arc, LazyLock, Mutex,
16 atomic::{AtomicU64, Ordering},
17 },
18 time::Duration,
19};
20
21use arc_swap::ArcSwap;
22use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier, password_hash::SaltString};
23use axum::{
24 body::Body,
25 extract::ConnectInfo,
26 http::{Request, header},
27 middleware::Next,
28 response::{IntoResponse, Response},
29};
30use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
31use secrecy::SecretString;
32use serde::Deserialize;
33use x509_parser::prelude::*;
34
35use crate::{bounded_limiter::BoundedKeyedLimiter, error::McpxError};
36
37#[derive(Clone)]
46#[non_exhaustive]
47pub struct AuthIdentity {
48 pub name: String,
50 pub role: String,
52 pub method: AuthMethod,
54 pub raw_token: Option<SecretString>,
60 pub sub: Option<String>,
63}
64
65impl std::fmt::Debug for AuthIdentity {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 f.debug_struct("AuthIdentity")
70 .field("name", &self.name)
71 .field("role", &self.role)
72 .field("method", &self.method)
73 .field(
74 "raw_token",
75 &if self.raw_token.is_some() {
76 "<redacted>"
77 } else {
78 "<none>"
79 },
80 )
81 .field(
82 "sub",
83 &if self.sub.is_some() {
84 "<redacted>"
85 } else {
86 "<none>"
87 },
88 )
89 .finish()
90 }
91}
92
/// Credential type that authenticated a request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum AuthMethod {
    /// Static API key presented as `Authorization: Bearer <token>`.
    BearerToken,
    /// Identity derived from the client certificate of an mTLS connection.
    MtlsCertificate,
    /// JWT accepted by the JWKS validation path. That path is compiled only
    /// with the `oauth` feature.
    OAuthJwt,
}
104
/// Reason an authentication attempt was rejected.
///
/// It selects the failure counter, the `WWW-Authenticate` challenge and the
/// plain-text response body.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AuthFailureClass {
    /// No `Authorization` header was sent.
    MissingCredential,
    /// A credential was sent but did not verify (includes expired API keys).
    InvalidCredential,
    /// The JWT validator reported the token as expired (`oauth` path only).
    #[cfg_attr(not(feature = "oauth"), allow(dead_code))]
    ExpiredCredential,
    /// Rejected by the failed-attempt limiter.
    RateLimited,
    /// Rejected by the pre-auth gate before any credential was checked.
    PreAuthGate,
}

impl AuthFailureClass {
    /// Stable snake_case label used as a log field value.
    fn as_str(self) -> &'static str {
        match self {
            Self::MissingCredential => "missing_credential",
            Self::InvalidCredential => "invalid_credential",
            Self::ExpiredCredential => "expired_credential",
            Self::RateLimited => "rate_limited",
            Self::PreAuthGate => "pre_auth_gate",
        }
    }

    /// `(error, error_description)` pair for the Bearer challenge (RFC 6750).
    fn bearer_error(self) -> (&'static str, &'static str) {
        let error = match self {
            Self::InvalidCredential | Self::ExpiredCredential => "invalid_token",
            Self::MissingCredential | Self::RateLimited | Self::PreAuthGate => "invalid_request",
        };
        let description = match self {
            Self::MissingCredential => "missing bearer token or mTLS client certificate",
            Self::InvalidCredential => "token is invalid",
            Self::ExpiredCredential => "token is expired",
            Self::RateLimited => "too many failed authentication attempts",
            Self::PreAuthGate => "too many unauthenticated requests from this source",
        };
        (error, description)
    }

    /// Plain-text body returned with the rejection.
    fn response_body(self) -> &'static str {
        match self {
            Self::MissingCredential => "unauthorized: missing credential",
            Self::InvalidCredential => "unauthorized: invalid credential",
            Self::ExpiredCredential => "unauthorized: expired credential",
            Self::RateLimited => "rate limited",
            Self::PreAuthGate => "rate limited (pre-auth)",
        }
    }
}
155
156#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
158#[non_exhaustive]
159pub struct AuthCountersSnapshot {
160 pub success_mtls: u64,
162 pub success_bearer: u64,
164 pub success_oauth_jwt: u64,
166 pub failure_missing_credential: u64,
168 pub failure_invalid_credential: u64,
170 pub failure_expired_credential: u64,
172 pub failure_rate_limited: u64,
174 pub failure_pre_auth_gate: u64,
177}
178
179#[derive(Debug, Default)]
181pub(crate) struct AuthCounters {
182 success_mtls: AtomicU64,
183 success_bearer: AtomicU64,
184 success_oauth_jwt: AtomicU64,
185 failure_missing_credential: AtomicU64,
186 failure_invalid_credential: AtomicU64,
187 failure_expired_credential: AtomicU64,
188 failure_rate_limited: AtomicU64,
189 failure_pre_auth_gate: AtomicU64,
190}
191
192impl AuthCounters {
193 fn record_success(&self, method: AuthMethod) {
194 match method {
195 AuthMethod::MtlsCertificate => {
196 self.success_mtls.fetch_add(1, Ordering::Relaxed);
197 }
198 AuthMethod::BearerToken => {
199 self.success_bearer.fetch_add(1, Ordering::Relaxed);
200 }
201 AuthMethod::OAuthJwt => {
202 self.success_oauth_jwt.fetch_add(1, Ordering::Relaxed);
203 }
204 }
205 }
206
207 fn record_failure(&self, class: AuthFailureClass) {
208 match class {
209 AuthFailureClass::MissingCredential => {
210 self.failure_missing_credential
211 .fetch_add(1, Ordering::Relaxed);
212 }
213 AuthFailureClass::InvalidCredential => {
214 self.failure_invalid_credential
215 .fetch_add(1, Ordering::Relaxed);
216 }
217 AuthFailureClass::ExpiredCredential => {
218 self.failure_expired_credential
219 .fetch_add(1, Ordering::Relaxed);
220 }
221 AuthFailureClass::RateLimited => {
222 self.failure_rate_limited.fetch_add(1, Ordering::Relaxed);
223 }
224 AuthFailureClass::PreAuthGate => {
225 self.failure_pre_auth_gate.fetch_add(1, Ordering::Relaxed);
226 }
227 }
228 }
229
230 fn snapshot(&self) -> AuthCountersSnapshot {
231 AuthCountersSnapshot {
232 success_mtls: self.success_mtls.load(Ordering::Relaxed),
233 success_bearer: self.success_bearer.load(Ordering::Relaxed),
234 success_oauth_jwt: self.success_oauth_jwt.load(Ordering::Relaxed),
235 failure_missing_credential: self.failure_missing_credential.load(Ordering::Relaxed),
236 failure_invalid_credential: self.failure_invalid_credential.load(Ordering::Relaxed),
237 failure_expired_credential: self.failure_expired_credential.load(Ordering::Relaxed),
238 failure_rate_limited: self.failure_rate_limited.load(Ordering::Relaxed),
239 failure_pre_auth_gate: self.failure_pre_auth_gate.load(Ordering::Relaxed),
240 }
241 }
242}
243
244#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
256#[non_exhaustive]
257pub struct RfcTimestamp(chrono::DateTime<chrono::FixedOffset>);
258
259impl RfcTimestamp {
260 pub fn parse(s: &str) -> Result<Self, chrono::ParseError> {
268 chrono::DateTime::parse_from_rfc3339(s).map(Self)
269 }
270
271 #[must_use]
273 pub fn as_datetime(&self) -> &chrono::DateTime<chrono::FixedOffset> {
274 &self.0
275 }
276
277 #[must_use]
279 pub fn into_inner(self) -> chrono::DateTime<chrono::FixedOffset> {
280 self.0
281 }
282}
283
284impl std::fmt::Display for RfcTimestamp {
285 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286 write!(f, "{}", self.0.to_rfc3339())
288 }
289}
290
291impl std::fmt::Debug for RfcTimestamp {
292 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
293 write!(f, "{}", self.0.to_rfc3339())
298 }
299}
300
301impl<'de> Deserialize<'de> for RfcTimestamp {
302 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
303 where
304 D: serde::Deserializer<'de>,
305 {
306 let s = String::deserialize(deserializer)?;
310 Self::parse(&s).map_err(serde::de::Error::custom)
311 }
312}
313
314impl serde::Serialize for RfcTimestamp {
315 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
316 where
317 S: serde::Serializer,
318 {
319 serializer.serialize_str(&self.0.to_rfc3339())
320 }
321}
322
323impl From<chrono::DateTime<chrono::FixedOffset>> for RfcTimestamp {
324 fn from(value: chrono::DateTime<chrono::FixedOffset>) -> Self {
325 Self(value)
326 }
327}
328
329#[derive(Clone, Deserialize)]
336#[non_exhaustive]
337pub struct ApiKeyEntry {
338 pub name: String,
340 pub hash: String,
342 pub role: String,
344 pub expires_at: Option<RfcTimestamp>,
349}
350
351impl std::fmt::Debug for ApiKeyEntry {
352 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355 f.debug_struct("ApiKeyEntry")
356 .field("name", &self.name)
357 .field("hash", &"<redacted>")
358 .field("role", &self.role)
359 .field("expires_at", &self.expires_at)
360 .finish()
361 }
362}
363
364impl ApiKeyEntry {
365 #[must_use]
367 pub fn new(name: impl Into<String>, hash: impl Into<String>, role: impl Into<String>) -> Self {
368 Self {
369 name: name.into(),
370 hash: hash.into(),
371 role: role.into(),
372 expires_at: None,
373 }
374 }
375
376 #[must_use]
381 pub fn with_expiry(mut self, expires_at: RfcTimestamp) -> Self {
382 self.expires_at = Some(expires_at);
383 self
384 }
385
386 pub fn try_with_expiry(
394 mut self,
395 expires_at: impl AsRef<str>,
396 ) -> Result<Self, chrono::ParseError> {
397 self.expires_at = Some(RfcTimestamp::parse(expires_at.as_ref())?);
398 Ok(self)
399 }
400}
401
402#[derive(Debug, Clone, Deserialize)]
404#[allow(
405 clippy::struct_excessive_bools,
406 reason = "mTLS CRL behavior is intentionally configured as independent booleans"
407)]
408#[non_exhaustive]
409pub struct MtlsConfig {
410 pub ca_cert_path: PathBuf,
412 #[serde(default)]
415 pub required: bool,
416 #[serde(default = "default_mtls_role")]
419 pub default_role: String,
420 #[serde(default = "default_true")]
423 pub crl_enabled: bool,
424 #[serde(default, with = "humantime_serde::option")]
427 pub crl_refresh_interval: Option<Duration>,
428 #[serde(default = "default_crl_fetch_timeout", with = "humantime_serde")]
430 pub crl_fetch_timeout: Duration,
431 #[serde(default = "default_crl_stale_grace", with = "humantime_serde")]
434 pub crl_stale_grace: Duration,
435 #[serde(default)]
438 pub crl_deny_on_unavailable: bool,
439 #[serde(default)]
441 pub crl_end_entity_only: bool,
442 #[serde(default = "default_true")]
451 pub crl_allow_http: bool,
452 #[serde(default = "default_true")]
454 pub crl_enforce_expiration: bool,
455 #[serde(default = "default_crl_max_concurrent_fetches")]
461 pub crl_max_concurrent_fetches: usize,
462 #[serde(default = "default_crl_max_response_bytes")]
466 pub crl_max_response_bytes: u64,
467 #[serde(default = "default_crl_discovery_rate_per_min")]
483 pub crl_discovery_rate_per_min: u32,
484 #[serde(default = "default_crl_max_host_semaphores")]
493 pub crl_max_host_semaphores: usize,
494 #[serde(default = "default_crl_max_seen_urls")]
498 pub crl_max_seen_urls: usize,
499 #[serde(default = "default_crl_max_cache_entries")]
503 pub crl_max_cache_entries: usize,
504}
505
506fn default_mtls_role() -> String {
507 "viewer".into()
508}
509
510const fn default_true() -> bool {
511 true
512}
513
514const fn default_crl_fetch_timeout() -> Duration {
515 Duration::from_secs(30)
516}
517
518const fn default_crl_stale_grace() -> Duration {
519 Duration::from_hours(24)
520}
521
522const fn default_crl_max_concurrent_fetches() -> usize {
523 4
524}
525
526const fn default_crl_max_response_bytes() -> u64 {
527 5 * 1024 * 1024
528}
529
530const fn default_crl_discovery_rate_per_min() -> u32 {
531 60
532}
533
534const fn default_crl_max_host_semaphores() -> usize {
535 1024
536}
537
538const fn default_crl_max_seen_urls() -> usize {
539 4096
540}
541
542const fn default_crl_max_cache_entries() -> usize {
543 1024
544}
545
546#[derive(Debug, Clone, Deserialize)]
561#[non_exhaustive]
562pub struct RateLimitConfig {
563 #[serde(default = "default_max_attempts")]
566 pub max_attempts_per_minute: u32,
567 #[serde(default)]
575 pub pre_auth_max_per_minute: Option<u32>,
576 #[serde(default = "default_max_tracked_keys")]
581 pub max_tracked_keys: usize,
582 #[serde(default = "default_idle_eviction", with = "humantime_serde")]
585 pub idle_eviction: Duration,
586}
587
588impl Default for RateLimitConfig {
589 fn default() -> Self {
590 Self {
591 max_attempts_per_minute: default_max_attempts(),
592 pre_auth_max_per_minute: None,
593 max_tracked_keys: default_max_tracked_keys(),
594 idle_eviction: default_idle_eviction(),
595 }
596 }
597}
598
599impl RateLimitConfig {
600 #[must_use]
604 pub fn new(max_attempts_per_minute: u32) -> Self {
605 Self {
606 max_attempts_per_minute,
607 ..Self::default()
608 }
609 }
610
611 #[must_use]
614 pub fn with_pre_auth_max_per_minute(mut self, quota: u32) -> Self {
615 self.pre_auth_max_per_minute = Some(quota);
616 self
617 }
618
619 #[must_use]
621 pub fn with_max_tracked_keys(mut self, max: usize) -> Self {
622 self.max_tracked_keys = max;
623 self
624 }
625
626 #[must_use]
628 pub fn with_idle_eviction(mut self, idle: Duration) -> Self {
629 self.idle_eviction = idle;
630 self
631 }
632}
633
634fn default_max_attempts() -> u32 {
635 30
636}
637
638fn default_max_tracked_keys() -> usize {
639 10_000
640}
641
642fn default_idle_eviction() -> Duration {
643 Duration::from_mins(15)
644}
645
646#[derive(Debug, Clone, Default, Deserialize)]
648#[non_exhaustive]
649pub struct AuthConfig {
650 #[serde(default)]
652 pub enabled: bool,
653 #[serde(default)]
655 pub api_keys: Vec<ApiKeyEntry>,
656 pub mtls: Option<MtlsConfig>,
658 pub rate_limit: Option<RateLimitConfig>,
660 #[cfg(feature = "oauth")]
662 pub oauth: Option<crate::oauth::OAuthConfig>,
663}
664
665impl AuthConfig {
666 #[must_use]
668 pub fn with_keys(keys: Vec<ApiKeyEntry>) -> Self {
669 Self {
670 enabled: true,
671 api_keys: keys,
672 mtls: None,
673 rate_limit: None,
674 #[cfg(feature = "oauth")]
675 oauth: None,
676 }
677 }
678
679 #[must_use]
681 pub fn with_rate_limit(mut self, rate_limit: RateLimitConfig) -> Self {
682 self.rate_limit = Some(rate_limit);
683 self
684 }
685}
686
687#[derive(Debug, Clone, serde::Serialize)]
691#[non_exhaustive]
692pub struct ApiKeySummary {
693 pub name: String,
695 pub role: String,
697 pub expires_at: Option<RfcTimestamp>,
700}
701
702#[derive(Debug, Clone, serde::Serialize)]
704#[allow(
705 clippy::struct_excessive_bools,
706 reason = "this is a flat summary of independent auth-method booleans"
707)]
708#[non_exhaustive]
709pub struct AuthConfigSummary {
710 pub enabled: bool,
712 pub bearer: bool,
714 pub mtls: bool,
716 pub oauth: bool,
718 pub api_keys: Vec<ApiKeySummary>,
720}
721
722impl AuthConfig {
723 #[must_use]
725 pub fn summary(&self) -> AuthConfigSummary {
726 AuthConfigSummary {
727 enabled: self.enabled,
728 bearer: !self.api_keys.is_empty(),
729 mtls: self.mtls.is_some(),
730 #[cfg(feature = "oauth")]
731 oauth: self.oauth.is_some(),
732 #[cfg(not(feature = "oauth"))]
733 oauth: false,
734 api_keys: self
735 .api_keys
736 .iter()
737 .map(|k| ApiKeySummary {
738 name: k.name.clone(),
739 role: k.role.clone(),
740 expires_at: k.expires_at,
741 })
742 .collect(),
743 }
744 }
745}
746
747pub(crate) type KeyedLimiter = BoundedKeyedLimiter<IpAddr>;
750
751#[derive(Clone, Debug)]
761#[non_exhaustive]
762pub(crate) struct TlsConnInfo {
763 pub addr: SocketAddr,
765 pub identity: Option<AuthIdentity>,
768}
769
770impl TlsConnInfo {
771 #[must_use]
773 pub(crate) const fn new(addr: SocketAddr, identity: Option<AuthIdentity>) -> Self {
774 Self { addr, identity }
775 }
776}
777
/// Number of distinct identity names `SeenIdentitySet::new` remembers.
const DEFAULT_SEEN_IDENTITY_CAP: usize = 4096;

/// Bounded, mutex-protected set of identity names that have authenticated.
///
/// Once `cap` names are stored, inserting a new name forgets the oldest
/// inserted one (FIFO). An evicted name is therefore reported as "first"
/// again.
pub(crate) struct SeenIdentitySet {
    inner: Mutex<SeenInner>,
}

/// Lock-protected state.
///
/// `set` answers membership and `order` records insertion order for
/// eviction. Both always hold the same names.
struct SeenInner {
    set: HashSet<String>,
    order: std::collections::VecDeque<String>,
    /// Maximum number of retained names (always at least 1).
    cap: usize,
}

impl SeenIdentitySet {
    /// Creates a set holding up to `DEFAULT_SEEN_IDENTITY_CAP` names.
    #[must_use]
    pub(crate) fn new() -> Self {
        Self::with_cap(DEFAULT_SEEN_IDENTITY_CAP)
    }

    /// Creates a set holding up to `cap` names; a zero `cap` is treated as 1.
    #[must_use]
    pub(crate) fn with_cap(cap: usize) -> Self {
        let cap = std::cmp::max(cap, 1);
        // Pre-allocate modestly; the collections grow on demand up to `cap`.
        let initial = std::cmp::min(cap, 64);
        let state = SeenInner {
            set: HashSet::with_capacity(initial),
            order: std::collections::VecDeque::with_capacity(initial),
            cap,
        };
        Self {
            inner: Mutex::new(state),
        }
    }

    /// Records `name` and reports whether it was not already present.
    ///
    /// A poisoned lock is recovered rather than propagated, since the set
    /// only affects log verbosity.
    pub(crate) fn insert_is_first(&self, name: &str) -> bool {
        let mut state = match self.inner.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };

        if state.set.contains(name) {
            return false;
        }

        let evicted = if state.set.len() >= state.cap {
            state.order.pop_front()
        } else {
            None
        };
        if let Some(oldest) = evicted {
            state.set.remove(&oldest);
        }

        state.set.insert(name.to_owned());
        state.order.push_back(name.to_owned());
        true
    }

    /// Number of names currently remembered.
    #[cfg(test)]
    pub(crate) fn len(&self) -> usize {
        let state = match self.inner.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        state.set.len()
    }
}

impl Default for SeenIdentitySet {
    fn default() -> Self {
        Self::new()
    }
}
890
891#[allow(
896 missing_debug_implementations,
897 reason = "contains governor RateLimiter and JwksCache without Debug impls"
898)]
899#[non_exhaustive]
900pub(crate) struct AuthState {
901 pub api_keys: ArcSwap<Vec<ApiKeyEntry>>,
903 pub rate_limiter: Option<Arc<KeyedLimiter>>,
905 pub pre_auth_limiter: Option<Arc<KeyedLimiter>>,
908 #[cfg(feature = "oauth")]
909 pub jwks_cache: Option<Arc<crate::oauth::JwksCache>>,
911 pub seen_identities: SeenIdentitySet,
916 pub counters: AuthCounters,
918}
919
920impl AuthState {
921 pub(crate) fn reload_keys(&self, keys: Vec<ApiKeyEntry>) {
927 let count = keys.len();
928 self.api_keys.store(Arc::new(keys));
929 tracing::info!(keys = count, "API keys reloaded");
930 }
931
932 #[must_use]
934 pub(crate) fn counters_snapshot(&self) -> AuthCountersSnapshot {
935 self.counters.snapshot()
936 }
937
938 #[must_use]
940 pub(crate) fn api_key_summaries(&self) -> Vec<ApiKeySummary> {
941 self.api_keys
942 .load()
943 .iter()
944 .map(|k| ApiKeySummary {
945 name: k.name.clone(),
946 role: k.role.clone(),
947 expires_at: k.expires_at,
948 })
949 .collect()
950 }
951
952 fn log_auth(&self, id: &AuthIdentity, method: &str) {
960 self.counters.record_success(id.method);
961 let first = self.seen_identities.insert_is_first(&id.name);
962 if first {
963 tracing::info!(name = %id.name, role = %id.role, "{method} authenticated");
964 } else {
965 tracing::debug!(name = %id.name, role = %id.role, "{method} authenticated");
966 }
967 }
968}
969
970const DEFAULT_AUTH_RATE: NonZeroU32 = NonZeroU32::new(30).unwrap();
973
974#[must_use]
976pub(crate) fn build_rate_limiter(config: &RateLimitConfig) -> Arc<KeyedLimiter> {
977 let quota = governor::Quota::per_minute(
978 NonZeroU32::new(config.max_attempts_per_minute).unwrap_or(DEFAULT_AUTH_RATE),
979 );
980 Arc::new(BoundedKeyedLimiter::new(
981 quota,
982 config.max_tracked_keys,
983 config.idle_eviction,
984 ))
985}
986
987#[must_use]
994pub(crate) fn build_pre_auth_limiter(config: &RateLimitConfig) -> Arc<KeyedLimiter> {
995 let resolved = config.pre_auth_max_per_minute.unwrap_or_else(|| {
996 config
997 .max_attempts_per_minute
998 .saturating_mul(PRE_AUTH_DEFAULT_MULTIPLIER)
999 });
1000 let quota =
1001 governor::Quota::per_minute(NonZeroU32::new(resolved).unwrap_or(DEFAULT_PRE_AUTH_RATE));
1002 Arc::new(BoundedKeyedLimiter::new(
1003 quota,
1004 config.max_tracked_keys,
1005 config.idle_eviction,
1006 ))
1007}
1008
1009const PRE_AUTH_DEFAULT_MULTIPLIER: u32 = 10;
1012
1013const DEFAULT_PRE_AUTH_RATE: NonZeroU32 = NonZeroU32::new(300).unwrap();
1017
1018#[must_use]
1023pub fn extract_mtls_identity(cert_der: &[u8], default_role: &str) -> Option<AuthIdentity> {
1024 let (_, cert) = X509Certificate::from_der(cert_der).ok()?;
1025
1026 let cn = cert
1028 .subject()
1029 .iter_common_name()
1030 .next()
1031 .and_then(|attr| attr.as_str().ok())
1032 .map(String::from);
1033
1034 let name = cn.or_else(|| {
1036 cert.subject_alternative_name()
1037 .ok()
1038 .flatten()
1039 .and_then(|san| {
1040 #[allow(
1041 clippy::wildcard_enum_match_arm,
1042 reason = "x509-parser GeneralName is a large external enum; only DNSName is meaningful here"
1043 )]
1044 san.value.general_names.iter().find_map(|gn| match gn {
1045 GeneralName::DNSName(dns) => Some((*dns).to_owned()),
1046 _ => None,
1047 })
1048 })
1049 })?;
1050
1051 if !name
1053 .chars()
1054 .all(|c| c.is_alphanumeric() || matches!(c, '-' | '.' | '_' | '@'))
1055 {
1056 tracing::warn!(cn = %name, "mTLS identity rejected: invalid characters in CN/SAN");
1057 return None;
1058 }
1059
1060 Some(AuthIdentity {
1061 name,
1062 role: default_role.to_owned(),
1063 method: AuthMethod::MtlsCertificate,
1064 raw_token: None,
1065 sub: None,
1066 })
1067}
1068
/// Extracts the token from an `Authorization` value of the form
/// `Bearer <token>`.
///
/// The scheme is matched case-insensitively and must be followed by a single
/// space; any further leading spaces are skipped. Returns `None` for other
/// schemes, a missing space, or an empty token.
fn extract_bearer(value: &str) -> Option<&str> {
    let (scheme, credentials) = value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    Some(credentials.trim_start_matches(' ')).filter(|token| !token.is_empty())
}
1092
/// Verifies `token` against the configured API keys.
///
/// Every entry costs one Argon2 verification, whether or not it matches.
/// Entries that are expired, have a malformed hash, or come after the first
/// match are verified against `DUMMY_PHC_HASH` instead of their own hash. The
/// work therefore scales with `keys.len()` rather than with the position of
/// the matching key. Timing is uniform only if the configured hashes use the
/// same Argon2 parameters as the dummy (NOTE(review): confirm for hand-made
/// hashes). When several entries accept the token, the first one wins.
///
/// Returns a `BearerToken` identity (no `raw_token`/`sub`) for the first
/// unexpired, well-formed entry whose hash accepts the token. Returns `None`
/// otherwise, and immediately when `keys` is empty.
///
/// This is CPU-heavy blocking work; the async caller in this module runs it
/// through `tokio::task::spawn_blocking`.
#[must_use]
pub fn verify_bearer_token(token: &str, keys: &[ApiKeyEntry]) -> Option<AuthIdentity> {
    use subtle::ConstantTimeEq as _;

    // Captured once so every entry is judged against the same instant.
    let now = chrono::Utc::now();
    #[allow(
        clippy::expect_used,
        reason = "DUMMY_PHC_HASH is a static LazyLock built from a fixed Argon2id PHC string by construction; PasswordHash::new on it is infallible. See DUMMY_PHC_HASH definition."
    )]
    let dummy_hash = PasswordHash::new(&DUMMY_PHC_HASH)
        .expect("DUMMY_PHC_HASH is a valid Argon2id PHC string by construction");

    // `usize::MAX` means "no match yet"; `keys.get` below maps it to `None`.
    let mut matched_index: usize = usize::MAX;
    // 0/1 flag kept as `u8` so it combines with the bitwise ops below.
    let mut any_match: u8 = 0;

    for (idx, key) in keys.iter().enumerate() {
        // Strict `<`: a key is still accepted at exactly its `expires_at` instant.
        let expired = key.expires_at.is_some_and(|exp| exp.as_datetime() < &now);

        let real_hash = PasswordHash::new(&key.hash);
        // Use the real hash only for a well-formed, unexpired entry while no
        // match has been found. Otherwise spend comparable Argon2 work on the
        // dummy hash so timing does not reveal which slot matched.
        let verify_against = match (&real_hash, expired, any_match) {
            (Ok(h), false, 0) => h,
            _ => &dummy_hash,
        };

        let slot_ok = u8::from(
            Argon2::default()
                .verify_password(token.as_bytes(), verify_against)
                .is_ok(),
        );

        // Discard dummy-hash successes for expired or malformed entries (e.g.
        // a token equal to the dummy plaintext).
        let real_match = slot_ok & u8::from(!expired) & u8::from(real_hash.is_ok());
        // Equals 1 only for the first matching entry; later matches are ignored.
        let first_real_match = real_match & (1 - any_match);
        if first_real_match.ct_eq(&1).into() {
            matched_index = idx;
        }
        any_match |= real_match;
    }

    if any_match == 0 {
        return None;
    }
    let key = keys.get(matched_index)?;
    Some(AuthIdentity {
        name: key.name.clone(),
        role: key.role.clone(),
        method: AuthMethod::BearerToken,
        raw_token: None,
        sub: None,
    })
}
1171
1172static DUMMY_PHC_HASH: LazyLock<String> = LazyLock::new(|| {
1185 #[allow(
1187 clippy::expect_used,
1188 reason = "fixed 22-char base64 ('AAAA...') decodes to a valid 16-byte salt; SaltString::from_b64 is infallible on this literal"
1189 )]
1190 let salt = SaltString::from_b64("AAAAAAAAAAAAAAAAAAAAAA")
1191 .expect("fixed 16-byte base64 salt is well-formed");
1192 #[allow(
1193 clippy::expect_used,
1194 reason = "Argon2::default() with a fixed plaintext and a well-formed salt is infallible; only fails on bad params/salt"
1195 )]
1196 Argon2::default()
1197 .hash_password(b"rmcp-server-kit-dummy", &salt)
1198 .expect("Argon2 default params hash a fixed plaintext")
1199 .to_string()
1200});
1201
1202pub fn generate_api_key() -> Result<(String, String), McpxError> {
1212 let mut token_bytes = [0u8; 32];
1213 rand::fill(&mut token_bytes);
1214 let token = URL_SAFE_NO_PAD.encode(token_bytes);
1215
1216 let mut salt_bytes = [0u8; 16];
1218 rand::fill(&mut salt_bytes);
1219 let salt = SaltString::encode_b64(&salt_bytes)
1220 .map_err(|e| McpxError::Auth(format!("salt encoding failed: {e}")))?;
1221 let hash = Argon2::default()
1222 .hash_password(token.as_bytes(), &salt)
1223 .map_err(|e| McpxError::Auth(format!("argon2id hashing failed: {e}")))?
1224 .to_string();
1225
1226 Ok((token, hash))
1227}
1228
1229fn build_www_authenticate_value(
1230 advertise_resource_metadata: bool,
1231 failure: AuthFailureClass,
1232) -> String {
1233 let (error, error_description) = failure.bearer_error();
1234 if advertise_resource_metadata {
1235 return format!(
1236 "Bearer resource_metadata=\"/.well-known/oauth-protected-resource\", error=\"{error}\", error_description=\"{error_description}\""
1237 );
1238 }
1239 format!("Bearer error=\"{error}\", error_description=\"{error_description}\"")
1240}
1241
1242fn auth_method_label(method: AuthMethod) -> &'static str {
1243 match method {
1244 AuthMethod::MtlsCertificate => "mTLS",
1245 AuthMethod::BearerToken => "bearer token",
1246 AuthMethod::OAuthJwt => "OAuth JWT",
1247 }
1248}
1249
1250#[cfg_attr(not(feature = "oauth"), allow(unused_variables))]
1251fn unauthorized_response(state: &AuthState, failure_class: AuthFailureClass) -> Response {
1252 #[cfg(feature = "oauth")]
1253 let advertise_resource_metadata = state.jwks_cache.is_some();
1254 #[cfg(not(feature = "oauth"))]
1255 let advertise_resource_metadata = false;
1256
1257 let challenge = build_www_authenticate_value(advertise_resource_metadata, failure_class);
1258 (
1259 axum::http::StatusCode::UNAUTHORIZED,
1260 [(header::WWW_AUTHENTICATE, challenge)],
1261 failure_class.response_body(),
1262 )
1263 .into_response()
1264}
1265
1266async fn authenticate_bearer_identity(
1267 state: &AuthState,
1268 token: &str,
1269) -> Result<AuthIdentity, AuthFailureClass> {
1270 let mut failure_class = AuthFailureClass::MissingCredential;
1271
1272 #[cfg(feature = "oauth")]
1273 if let Some(ref cache) = state.jwks_cache
1274 && crate::oauth::looks_like_jwt(token)
1275 {
1276 match cache.validate_token_with_reason(token).await {
1277 Ok(mut id) => {
1278 id.raw_token = Some(SecretString::from(token.to_owned()));
1279 return Ok(id);
1280 }
1281 Err(crate::oauth::JwtValidationFailure::Expired) => {
1282 failure_class = AuthFailureClass::ExpiredCredential;
1283 }
1284 Err(crate::oauth::JwtValidationFailure::Invalid) => {
1285 failure_class = AuthFailureClass::InvalidCredential;
1286 }
1287 }
1288 }
1289
1290 let token = token.to_owned();
1291 let keys = state.api_keys.load_full(); let identity = tokio::task::spawn_blocking(move || verify_bearer_token(&token, &keys))
1295 .await
1296 .ok()
1297 .flatten();
1298
1299 if let Some(id) = identity {
1300 return Ok(id);
1301 }
1302
1303 if failure_class == AuthFailureClass::MissingCredential {
1304 failure_class = AuthFailureClass::InvalidCredential;
1305 }
1306
1307 Err(failure_class)
1308}
1309
1310fn pre_auth_gate(state: &AuthState, peer_addr: Option<SocketAddr>) -> Option<Response> {
1321 let limiter = state.pre_auth_limiter.as_ref()?;
1322 let addr = peer_addr?;
1323 if limiter.check_key(&addr.ip()).is_ok() {
1324 return None;
1325 }
1326 state.counters.record_failure(AuthFailureClass::PreAuthGate);
1327 tracing::warn!(
1328 ip = %addr.ip(),
1329 "auth rate limited by pre-auth gate (request rejected before credential verification)"
1330 );
1331 Some(
1332 McpxError::RateLimited("too many unauthenticated requests from this source".into())
1333 .into_response(),
1334 )
1335}
1336
1337pub(crate) async fn auth_middleware(
1346 state: Arc<AuthState>,
1347 req: Request<Body>,
1348 next: Next,
1349) -> Response {
1350 let tls_info = req.extensions().get::<ConnectInfo<TlsConnInfo>>().cloned();
1355 let peer_addr = req
1356 .extensions()
1357 .get::<ConnectInfo<SocketAddr>>()
1358 .map(|ci| ci.0)
1359 .or_else(|| tls_info.as_ref().map(|ci| ci.0.addr));
1360
1361 if let Some(id) = tls_info.and_then(|ci| ci.0.identity) {
1368 state.log_auth(&id, "mTLS");
1369 let mut req = req;
1370 req.extensions_mut().insert(id);
1371 return next.run(req).await;
1372 }
1373
1374 if let Some(blocked) = pre_auth_gate(&state, peer_addr) {
1378 return blocked;
1379 }
1380
1381 let failure_class = if let Some(value) = req.headers().get(header::AUTHORIZATION) {
1382 match value.to_str().ok().and_then(extract_bearer) {
1383 Some(token) => match authenticate_bearer_identity(&state, token).await {
1384 Ok(id) => {
1385 state.log_auth(&id, auth_method_label(id.method));
1386 let mut req = req;
1387 req.extensions_mut().insert(id);
1388 return next.run(req).await;
1389 }
1390 Err(class) => class,
1391 },
1392 None => AuthFailureClass::InvalidCredential,
1393 }
1394 } else {
1395 AuthFailureClass::MissingCredential
1396 };
1397
1398 tracing::warn!(failure_class = %failure_class.as_str(), "auth failed");
1399
1400 if let (Some(limiter), Some(addr)) = (&state.rate_limiter, peer_addr)
1403 && limiter.check_key(&addr.ip()).is_err()
1404 {
1405 state.counters.record_failure(AuthFailureClass::RateLimited);
1406 tracing::warn!(ip = %addr.ip(), "auth rate limited after repeated failures");
1407 return McpxError::RateLimited("too many failed authentication attempts".into())
1408 .into_response();
1409 }
1410
1411 state.counters.record_failure(failure_class);
1412 unauthorized_response(&state, failure_class)
1413}
1414
1415#[cfg(test)]
1416mod tests {
1417 use super::*;
1418
1419 #[test]
1420 fn generate_and_verify_api_key() {
1421 let (token, hash) = generate_api_key().unwrap();
1422
1423 assert_eq!(token.len(), 43);
1425
1426 assert!(hash.starts_with("$argon2id$"));
1428
1429 let keys = vec![ApiKeyEntry {
1431 name: "test".into(),
1432 hash,
1433 role: "viewer".into(),
1434 expires_at: None,
1435 }];
1436 let id = verify_bearer_token(&token, &keys);
1437 assert!(id.is_some());
1438 let id = id.unwrap();
1439 assert_eq!(id.name, "test");
1440 assert_eq!(id.role, "viewer");
1441 assert_eq!(id.method, AuthMethod::BearerToken);
1442 }
1443
1444 #[test]
1445 fn wrong_token_rejected() {
1446 let (_token, hash) = generate_api_key().unwrap();
1447 let keys = vec![ApiKeyEntry {
1448 name: "test".into(),
1449 hash,
1450 role: "viewer".into(),
1451 expires_at: None,
1452 }];
1453 assert!(verify_bearer_token("wrong-token", &keys).is_none());
1454 }
1455
1456 #[test]
1457 fn expired_key_rejected() {
1458 let (token, hash) = generate_api_key().unwrap();
1459 let keys = vec![ApiKeyEntry {
1460 name: "test".into(),
1461 hash,
1462 role: "viewer".into(),
1463 expires_at: Some(RfcTimestamp::parse("2020-01-01T00:00:00Z").unwrap()),
1464 }];
1465 assert!(verify_bearer_token(&token, &keys).is_none());
1466 }
1467
1468 #[test]
1469 fn match_in_last_slot_still_authenticates() {
1470 let (token, hash) = generate_api_key().unwrap();
1471 let (_other_token, other_hash) = generate_api_key().unwrap();
1472 let keys = vec![
1473 ApiKeyEntry {
1474 name: "first".into(),
1475 hash: other_hash.clone(),
1476 role: "viewer".into(),
1477 expires_at: None,
1478 },
1479 ApiKeyEntry {
1480 name: "second".into(),
1481 hash: other_hash,
1482 role: "viewer".into(),
1483 expires_at: None,
1484 },
1485 ApiKeyEntry {
1486 name: "match".into(),
1487 hash,
1488 role: "ops".into(),
1489 expires_at: None,
1490 },
1491 ];
1492 let id = verify_bearer_token(&token, &keys).expect("last-slot match must authenticate");
1493 assert_eq!(id.name, "match");
1494 assert_eq!(id.role, "ops");
1495 }
1496
1497 #[test]
1498 fn expired_slot_before_valid_match_does_not_short_circuit() {
1499 let (token, hash) = generate_api_key().unwrap();
1500 let (_, other_hash) = generate_api_key().unwrap();
1501 let keys = vec![
1502 ApiKeyEntry {
1503 name: "expired".into(),
1504 hash: other_hash,
1505 role: "viewer".into(),
1506 expires_at: Some(RfcTimestamp::parse("2020-01-01T00:00:00Z").unwrap()),
1507 },
1508 ApiKeyEntry {
1509 name: "valid".into(),
1510 hash,
1511 role: "ops".into(),
1512 expires_at: None,
1513 },
1514 ];
1515 let id = verify_bearer_token(&token, &keys)
1516 .expect("valid slot following an expired slot must authenticate");
1517 assert_eq!(id.name, "valid");
1518 }
1519
1520 #[test]
1521 fn malformed_hash_slot_does_not_short_circuit() {
1522 let (token, hash) = generate_api_key().unwrap();
1523 let keys = vec![
1524 ApiKeyEntry {
1525 name: "broken".into(),
1526 hash: "this-is-not-a-phc-string".into(),
1527 role: "viewer".into(),
1528 expires_at: None,
1529 },
1530 ApiKeyEntry {
1531 name: "valid".into(),
1532 hash,
1533 role: "ops".into(),
1534 expires_at: None,
1535 },
1536 ];
1537 let id = verify_bearer_token(&token, &keys)
1538 .expect("valid slot following a malformed-hash slot must authenticate");
1539 assert_eq!(id.name, "valid");
1540 }
1541
1542 #[test]
1553 fn rfc_timestamp_parse_rejects_malformed() {
1554 for bad in [
1555 "not-a-date",
1556 "",
1557 "2025-13-01T00:00:00Z", "2025-01-32T00:00:00Z", "2025-01-01T00:00:00", "01/01/2025", "2025-01-01T25:00:00Z", ] {
1563 assert!(
1564 RfcTimestamp::parse(bad).is_err(),
1565 "RfcTimestamp::parse must reject {bad:?}"
1566 );
1567 }
1568 }
1569
1570 #[test]
1571 fn rfc_timestamp_parse_accepts_valid() {
1572 for good in [
1573 "2025-01-01T00:00:00Z",
1574 "2025-01-01T00:00:00+00:00",
1575 "2025-12-31T23:59:59-08:00",
1576 "2099-01-01T00:00:00.123456789Z",
1577 ] {
1578 assert!(
1579 RfcTimestamp::parse(good).is_ok(),
1580 "RfcTimestamp::parse must accept {good:?}"
1581 );
1582 }
1583 }
1584
1585 #[test]
1586 fn api_key_entry_deserialize_rejects_malformed_expires_at() {
1587 let toml = r#"
1592 name = "bad-key"
1593 hash = "$argon2id$v=19$m=19456,t=2,p=1$c2FsdA$h4sh"
1594 role = "viewer"
1595 expires_at = "not-a-date"
1596 "#;
1597 let result: Result<ApiKeyEntry, _> = toml::from_str(toml);
1598 assert!(
1599 result.is_err(),
1600 "deserialization must reject malformed expires_at"
1601 );
1602 }
1603
1604 #[test]
1605 fn api_key_entry_deserialize_accepts_valid_expires_at() {
1606 let toml = r#"
1607 name = "good-key"
1608 hash = "$argon2id$v=19$m=19456,t=2,p=1$c2FsdA$h4sh"
1609 role = "viewer"
1610 expires_at = "2099-01-01T00:00:00Z"
1611 "#;
1612 let entry: ApiKeyEntry = toml::from_str(toml).expect("valid RFC 3339 must deserialize");
1613 assert!(entry.expires_at.is_some());
1614 }
1615
1616 #[test]
1617 fn api_key_entry_deserialize_accepts_missing_expires_at() {
1618 let toml = r#"
1621 name = "eternal-key"
1622 hash = "$argon2id$v=19$m=19456,t=2,p=1$c2FsdA$h4sh"
1623 role = "viewer"
1624 "#;
1625 let entry: ApiKeyEntry = toml::from_str(toml).expect("missing expires_at must deserialize");
1626 assert!(entry.expires_at.is_none());
1627 }
1628
1629 #[test]
1630 fn try_with_expiry_rejects_malformed() {
1631 let entry = ApiKeyEntry::new("k", "hash", "viewer");
1632 assert!(entry.try_with_expiry("not-a-date").is_err());
1633 }
1634
1635 #[test]
1636 fn try_with_expiry_accepts_valid() {
1637 let entry = ApiKeyEntry::new("k", "hash", "viewer")
1638 .try_with_expiry("2099-01-01T00:00:00Z")
1639 .expect("valid RFC 3339 must be accepted");
1640 assert!(entry.expires_at.is_some());
1641 }
1642
1643 #[test]
1644 fn api_key_summary_serializes_expires_at_as_rfc3339() {
1645 let summary = ApiKeySummary {
1650 name: "k".into(),
1651 role: "viewer".into(),
1652 expires_at: Some(RfcTimestamp::parse("2030-01-01T00:00:00Z").unwrap()),
1653 };
1654 let json = serde_json::to_string(&summary).unwrap();
1655 assert!(
1656 json.contains(r#""expires_at":"2030-01-01T00:00:00+00:00""#),
1657 "wire format regressed: {json}"
1658 );
1659 }
1660
1661 #[test]
1662 fn future_expiry_accepted() {
1663 let (token, hash) = generate_api_key().unwrap();
1664 let keys = vec![ApiKeyEntry {
1665 name: "test".into(),
1666 hash,
1667 role: "viewer".into(),
1668 expires_at: Some(RfcTimestamp::parse("2099-01-01T00:00:00Z").unwrap()),
1669 }];
1670 assert!(verify_bearer_token(&token, &keys).is_some());
1671 }
1672
1673 #[test]
1674 fn multiple_keys_first_match_wins() {
1675 let (token, hash) = generate_api_key().unwrap();
1676 let keys = vec![
1677 ApiKeyEntry {
1678 name: "wrong".into(),
1679 hash: "$argon2id$v=19$m=19456,t=2,p=1$invalid$invalid".into(),
1680 role: "ops".into(),
1681 expires_at: None,
1682 },
1683 ApiKeyEntry {
1684 name: "correct".into(),
1685 hash,
1686 role: "deploy".into(),
1687 expires_at: None,
1688 },
1689 ];
1690 let id = verify_bearer_token(&token, &keys).unwrap();
1691 assert_eq!(id.name, "correct");
1692 assert_eq!(id.role, "deploy");
1693 }
1694
1695 #[test]
1696 fn rate_limiter_allows_within_quota() {
1697 let config = RateLimitConfig {
1698 max_attempts_per_minute: 5,
1699 pre_auth_max_per_minute: None,
1700 max_tracked_keys: default_max_tracked_keys(),
1701 idle_eviction: default_idle_eviction(),
1702 };
1703 let limiter = build_rate_limiter(&config);
1704 let ip: IpAddr = "10.0.0.1".parse().unwrap();
1705
1706 for _ in 0..5 {
1708 assert!(limiter.check_key(&ip).is_ok());
1709 }
1710 assert!(limiter.check_key(&ip).is_err());
1712 }
1713
1714 #[test]
1715 fn rate_limiter_separate_ips() {
1716 let config = RateLimitConfig {
1717 max_attempts_per_minute: 2,
1718 pre_auth_max_per_minute: None,
1719 max_tracked_keys: default_max_tracked_keys(),
1720 idle_eviction: default_idle_eviction(),
1721 };
1722 let limiter = build_rate_limiter(&config);
1723 let ip1: IpAddr = "10.0.0.1".parse().unwrap();
1724 let ip2: IpAddr = "10.0.0.2".parse().unwrap();
1725
1726 assert!(limiter.check_key(&ip1).is_ok());
1728 assert!(limiter.check_key(&ip1).is_ok());
1729 assert!(limiter.check_key(&ip1).is_err());
1730
1731 assert!(limiter.check_key(&ip2).is_ok());
1733 }
1734
1735 #[test]
1736 fn extract_mtls_identity_from_cn() {
1737 let mut params = rcgen::CertificateParams::new(vec!["test-client.local".into()]).unwrap();
1739 params.distinguished_name = rcgen::DistinguishedName::new();
1740 params
1741 .distinguished_name
1742 .push(rcgen::DnType::CommonName, "test-client");
1743 let cert = params
1744 .self_signed(&rcgen::KeyPair::generate().unwrap())
1745 .unwrap();
1746 let der = cert.der();
1747
1748 let id = extract_mtls_identity(der, "ops").unwrap();
1749 assert_eq!(id.name, "test-client");
1750 assert_eq!(id.role, "ops");
1751 assert_eq!(id.method, AuthMethod::MtlsCertificate);
1752 }
1753
1754 #[test]
1755 fn extract_mtls_identity_falls_back_to_san() {
1756 let mut params =
1758 rcgen::CertificateParams::new(vec!["san-only.example.com".into()]).unwrap();
1759 params.distinguished_name = rcgen::DistinguishedName::new();
1760 let cert = params
1762 .self_signed(&rcgen::KeyPair::generate().unwrap())
1763 .unwrap();
1764 let der = cert.der();
1765
1766 let id = extract_mtls_identity(der, "viewer").unwrap();
1767 assert_eq!(id.name, "san-only.example.com");
1768 assert_eq!(id.role, "viewer");
1769 }
1770
1771 #[test]
1772 fn extract_mtls_identity_invalid_der() {
1773 assert!(extract_mtls_identity(b"not-a-cert", "viewer").is_none());
1774 }
1775
1776 use axum::{
1779 body::Body,
1780 http::{Request, StatusCode},
1781 };
1782 use tower::ServiceExt as _;
1783
1784 fn auth_router(state: Arc<AuthState>) -> axum::Router {
1785 axum::Router::new()
1786 .route("/mcp", axum::routing::post(|| async { "ok" }))
1787 .layer(axum::middleware::from_fn(move |req, next| {
1788 let s = Arc::clone(&state);
1789 auth_middleware(s, req, next)
1790 }))
1791 }
1792
1793 fn test_auth_state(keys: Vec<ApiKeyEntry>) -> Arc<AuthState> {
1794 Arc::new(AuthState {
1795 api_keys: ArcSwap::new(Arc::new(keys)),
1796 rate_limiter: None,
1797 pre_auth_limiter: None,
1798 #[cfg(feature = "oauth")]
1799 jwks_cache: None,
1800 seen_identities: SeenIdentitySet::new(),
1801 counters: AuthCounters::default(),
1802 })
1803 }
1804
1805 #[tokio::test]
1806 async fn middleware_rejects_no_credentials() {
1807 let state = test_auth_state(vec![]);
1808 let app = auth_router(Arc::clone(&state));
1809 let req = Request::builder()
1810 .method(axum::http::Method::POST)
1811 .uri("/mcp")
1812 .body(Body::empty())
1813 .unwrap();
1814 let resp = app.oneshot(req).await.unwrap();
1815 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
1816 let challenge = resp
1817 .headers()
1818 .get(header::WWW_AUTHENTICATE)
1819 .unwrap()
1820 .to_str()
1821 .unwrap();
1822 assert!(challenge.contains("error=\"invalid_request\""));
1823
1824 let counters = state.counters_snapshot();
1825 assert_eq!(counters.failure_missing_credential, 1);
1826 }
1827
1828 #[tokio::test]
1829 async fn middleware_accepts_valid_bearer() {
1830 let (token, hash) = generate_api_key().unwrap();
1831 let keys = vec![ApiKeyEntry {
1832 name: "test-key".into(),
1833 hash,
1834 role: "ops".into(),
1835 expires_at: None,
1836 }];
1837 let state = test_auth_state(keys);
1838 let app = auth_router(Arc::clone(&state));
1839 let req = Request::builder()
1840 .method(axum::http::Method::POST)
1841 .uri("/mcp")
1842 .header("authorization", format!("Bearer {token}"))
1843 .body(Body::empty())
1844 .unwrap();
1845 let resp = app.oneshot(req).await.unwrap();
1846 assert_eq!(resp.status(), StatusCode::OK);
1847
1848 let counters = state.counters_snapshot();
1849 assert_eq!(counters.success_bearer, 1);
1850 }
1851
1852 #[tokio::test]
1853 async fn middleware_rejects_wrong_bearer() {
1854 let (_token, hash) = generate_api_key().unwrap();
1855 let keys = vec![ApiKeyEntry {
1856 name: "test-key".into(),
1857 hash,
1858 role: "ops".into(),
1859 expires_at: None,
1860 }];
1861 let state = test_auth_state(keys);
1862 let app = auth_router(Arc::clone(&state));
1863 let req = Request::builder()
1864 .method(axum::http::Method::POST)
1865 .uri("/mcp")
1866 .header("authorization", "Bearer wrong-token-here")
1867 .body(Body::empty())
1868 .unwrap();
1869 let resp = app.oneshot(req).await.unwrap();
1870 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
1871 let challenge = resp
1872 .headers()
1873 .get(header::WWW_AUTHENTICATE)
1874 .unwrap()
1875 .to_str()
1876 .unwrap();
1877 assert!(challenge.contains("error=\"invalid_token\""));
1878
1879 let counters = state.counters_snapshot();
1880 assert_eq!(counters.failure_invalid_credential, 1);
1881 }
1882
1883 #[tokio::test]
1884 async fn middleware_rate_limits() {
1885 let state = Arc::new(AuthState {
1886 api_keys: ArcSwap::new(Arc::new(vec![])),
1887 rate_limiter: Some(build_rate_limiter(&RateLimitConfig {
1888 max_attempts_per_minute: 1,
1889 pre_auth_max_per_minute: None,
1890 max_tracked_keys: default_max_tracked_keys(),
1891 idle_eviction: default_idle_eviction(),
1892 })),
1893 pre_auth_limiter: None,
1894 #[cfg(feature = "oauth")]
1895 jwks_cache: None,
1896 seen_identities: SeenIdentitySet::new(),
1897 counters: AuthCounters::default(),
1898 });
1899 let app = auth_router(state);
1900
1901 let req = Request::builder()
1903 .method(axum::http::Method::POST)
1904 .uri("/mcp")
1905 .body(Body::empty())
1906 .unwrap();
1907 let resp = app.clone().oneshot(req).await.unwrap();
1908 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
1909
1910 }
1915
1916 #[test]
1922 fn rate_limit_semantics_failed_only() {
1923 let config = RateLimitConfig {
1924 max_attempts_per_minute: 3,
1925 pre_auth_max_per_minute: None,
1926 max_tracked_keys: default_max_tracked_keys(),
1927 idle_eviction: default_idle_eviction(),
1928 };
1929 let limiter = build_rate_limiter(&config);
1930 let ip: IpAddr = "192.168.1.100".parse().unwrap();
1931
1932 assert!(
1934 limiter.check_key(&ip).is_ok(),
1935 "failure 1 should be allowed"
1936 );
1937 assert!(
1938 limiter.check_key(&ip).is_ok(),
1939 "failure 2 should be allowed"
1940 );
1941 assert!(
1942 limiter.check_key(&ip).is_ok(),
1943 "failure 3 should be allowed"
1944 );
1945 assert!(
1946 limiter.check_key(&ip).is_err(),
1947 "failure 4 should be blocked"
1948 );
1949
1950 }
1959
1960 #[test]
1965 fn pre_auth_default_multiplier_is_10x() {
1966 let config = RateLimitConfig {
1967 max_attempts_per_minute: 5,
1968 pre_auth_max_per_minute: None,
1969 max_tracked_keys: default_max_tracked_keys(),
1970 idle_eviction: default_idle_eviction(),
1971 };
1972 let limiter = build_pre_auth_limiter(&config);
1973 let ip: IpAddr = "10.0.0.1".parse().unwrap();
1974
1975 for i in 0..50 {
1977 assert!(
1978 limiter.check_key(&ip).is_ok(),
1979 "pre-auth attempt {i} (of expected 50) should be allowed under default 10x multiplier"
1980 );
1981 }
1982 assert!(
1984 limiter.check_key(&ip).is_err(),
1985 "pre-auth attempt 51 should be blocked (quota is 50, not unbounded)"
1986 );
1987 }
1988
1989 #[test]
1992 fn pre_auth_explicit_override_wins() {
1993 let config = RateLimitConfig {
1994 max_attempts_per_minute: 100, pre_auth_max_per_minute: Some(2), max_tracked_keys: default_max_tracked_keys(),
1997 idle_eviction: default_idle_eviction(),
1998 };
1999 let limiter = build_pre_auth_limiter(&config);
2000 let ip: IpAddr = "10.0.0.2".parse().unwrap();
2001
2002 assert!(limiter.check_key(&ip).is_ok(), "attempt 1 allowed");
2003 assert!(limiter.check_key(&ip).is_ok(), "attempt 2 allowed");
2004 assert!(
2005 limiter.check_key(&ip).is_err(),
2006 "attempt 3 must be blocked (explicit override of 2 wins over 10x default of 1000)"
2007 );
2008 }
2009
2010 #[tokio::test]
2016 async fn pre_auth_gate_blocks_before_argon2_verification() {
2017 let (_token, hash) = generate_api_key().unwrap();
2018 let keys = vec![ApiKeyEntry {
2019 name: "test-key".into(),
2020 hash,
2021 role: "ops".into(),
2022 expires_at: None,
2023 }];
2024 let config = RateLimitConfig {
2025 max_attempts_per_minute: 100,
2026 pre_auth_max_per_minute: Some(1),
2027 max_tracked_keys: default_max_tracked_keys(),
2028 idle_eviction: default_idle_eviction(),
2029 };
2030 let state = Arc::new(AuthState {
2031 api_keys: ArcSwap::new(Arc::new(keys)),
2032 rate_limiter: None,
2033 pre_auth_limiter: Some(build_pre_auth_limiter(&config)),
2034 #[cfg(feature = "oauth")]
2035 jwks_cache: None,
2036 seen_identities: SeenIdentitySet::new(),
2037 counters: AuthCounters::default(),
2038 });
2039 let app = auth_router(Arc::clone(&state));
2040 let peer: SocketAddr = "10.0.0.10:54321".parse().unwrap();
2041
2042 let mut req1 = Request::builder()
2045 .method(axum::http::Method::POST)
2046 .uri("/mcp")
2047 .header("authorization", "Bearer obviously-not-a-real-token")
2048 .body(Body::empty())
2049 .unwrap();
2050 req1.extensions_mut().insert(ConnectInfo(peer));
2051 let resp1 = app.clone().oneshot(req1).await.unwrap();
2052 assert_eq!(
2053 resp1.status(),
2054 StatusCode::UNAUTHORIZED,
2055 "first attempt: gate has quota, falls through to bearer auth which fails with 401"
2056 );
2057
2058 let mut req2 = Request::builder()
2061 .method(axum::http::Method::POST)
2062 .uri("/mcp")
2063 .header("authorization", "Bearer also-not-a-real-token")
2064 .body(Body::empty())
2065 .unwrap();
2066 req2.extensions_mut().insert(ConnectInfo(peer));
2067 let resp2 = app.oneshot(req2).await.unwrap();
2068 assert_eq!(
2069 resp2.status(),
2070 StatusCode::TOO_MANY_REQUESTS,
2071 "second attempt from same IP: pre-auth gate must reject with 429"
2072 );
2073
2074 let counters = state.counters_snapshot();
2075 assert_eq!(
2076 counters.failure_pre_auth_gate, 1,
2077 "exactly one request must have been rejected by the pre-auth gate"
2078 );
2079 assert_eq!(
2083 counters.failure_invalid_credential, 1,
2084 "bearer verification must run exactly once (only the un-gated first request)"
2085 );
2086 }
2087
2088 #[tokio::test]
2095 async fn pre_auth_gate_does_not_throttle_mtls() {
2096 let config = RateLimitConfig {
2097 max_attempts_per_minute: 100,
2098 pre_auth_max_per_minute: Some(1), max_tracked_keys: default_max_tracked_keys(),
2100 idle_eviction: default_idle_eviction(),
2101 };
2102 let state = Arc::new(AuthState {
2103 api_keys: ArcSwap::new(Arc::new(vec![])),
2104 rate_limiter: None,
2105 pre_auth_limiter: Some(build_pre_auth_limiter(&config)),
2106 #[cfg(feature = "oauth")]
2107 jwks_cache: None,
2108 seen_identities: SeenIdentitySet::new(),
2109 counters: AuthCounters::default(),
2110 });
2111 let app = auth_router(Arc::clone(&state));
2112 let peer: SocketAddr = "10.0.0.20:54321".parse().unwrap();
2113 let identity = AuthIdentity {
2114 name: "cn=test-client".into(),
2115 role: "viewer".into(),
2116 method: AuthMethod::MtlsCertificate,
2117 raw_token: None,
2118 sub: None,
2119 };
2120 let tls_info = TlsConnInfo::new(peer, Some(identity));
2121
2122 for i in 0..3 {
2123 let mut req = Request::builder()
2124 .method(axum::http::Method::POST)
2125 .uri("/mcp")
2126 .body(Body::empty())
2127 .unwrap();
2128 req.extensions_mut().insert(ConnectInfo(tls_info.clone()));
2129 let resp = app.clone().oneshot(req).await.unwrap();
2130 assert_eq!(
2131 resp.status(),
2132 StatusCode::OK,
2133 "mTLS request {i} must succeed: pre-auth gate must not apply to mTLS callers"
2134 );
2135 }
2136
2137 let counters = state.counters_snapshot();
2138 assert_eq!(
2139 counters.failure_pre_auth_gate, 0,
2140 "pre-auth gate counter must remain at zero: mTLS bypasses the gate"
2141 );
2142 assert_eq!(
2143 counters.success_mtls, 3,
2144 "all three mTLS requests must have been counted as successful"
2145 );
2146 }
2147
2148 #[test]
2153 fn extract_bearer_accepts_canonical_case() {
2154 assert_eq!(extract_bearer("Bearer abc123"), Some("abc123"));
2155 }
2156
2157 #[test]
2158 fn extract_bearer_is_case_insensitive_per_rfc7235() {
2159 for header in &[
2163 "bearer abc123",
2164 "BEARER abc123",
2165 "BeArEr abc123",
2166 "bEaReR abc123",
2167 ] {
2168 assert_eq!(
2169 extract_bearer(header),
2170 Some("abc123"),
2171 "header {header:?} must parse as a Bearer token (RFC 7235 §2.1)"
2172 );
2173 }
2174 }
2175
2176 #[test]
2177 fn extract_bearer_rejects_other_schemes() {
2178 assert_eq!(extract_bearer("Basic dXNlcjpwYXNz"), None);
2179 assert_eq!(extract_bearer("Digest username=\"x\""), None);
2180 assert_eq!(extract_bearer("Token abc123"), None);
2181 }
2182
2183 #[test]
2184 fn extract_bearer_rejects_malformed() {
2185 assert_eq!(extract_bearer(""), None);
2187 assert_eq!(extract_bearer("Bearer"), None);
2188 assert_eq!(extract_bearer("Bearer "), None);
2189 assert_eq!(extract_bearer("Bearer "), None);
2190 }
2191
2192 #[test]
2193 fn extract_bearer_tolerates_extra_separator_whitespace() {
2194 assert_eq!(extract_bearer("Bearer abc123"), Some("abc123"));
2196 assert_eq!(extract_bearer("Bearer abc123"), Some("abc123"));
2197 }
2198
2199 #[test]
2205 fn auth_identity_debug_redacts_raw_token() {
2206 let id = AuthIdentity {
2207 name: "alice".into(),
2208 role: "admin".into(),
2209 method: AuthMethod::OAuthJwt,
2210 raw_token: Some(SecretString::from("super-secret-jwt-payload-xyz")),
2211 sub: Some("keycloak-uuid-2f3c8b".into()),
2212 };
2213 let dbg = format!("{id:?}");
2214
2215 assert!(dbg.contains("alice"), "name should be visible: {dbg}");
2217 assert!(dbg.contains("admin"), "role should be visible: {dbg}");
2218 assert!(dbg.contains("OAuthJwt"), "method should be visible: {dbg}");
2219
2220 assert!(
2222 !dbg.contains("super-secret-jwt-payload-xyz"),
2223 "raw_token must be redacted in Debug output: {dbg}"
2224 );
2225 assert!(
2226 !dbg.contains("keycloak-uuid-2f3c8b"),
2227 "sub must be redacted in Debug output: {dbg}"
2228 );
2229 assert!(
2230 dbg.contains("<redacted>"),
2231 "redaction marker missing: {dbg}"
2232 );
2233 }
2234
2235 #[test]
2236 fn auth_identity_debug_marks_absent_secrets() {
2237 let id = AuthIdentity {
2240 name: "viewer-key".into(),
2241 role: "viewer".into(),
2242 method: AuthMethod::BearerToken,
2243 raw_token: None,
2244 sub: None,
2245 };
2246 let dbg = format!("{id:?}");
2247 assert!(
2248 dbg.contains("<none>"),
2249 "absent secrets should be marked: {dbg}"
2250 );
2251 assert!(
2252 !dbg.contains("<redacted>"),
2253 "no <redacted> marker when secrets are absent: {dbg}"
2254 );
2255 }
2256
2257 #[test]
2258 fn api_key_entry_debug_redacts_hash() {
2259 let entry = ApiKeyEntry {
2260 name: "viewer-key".into(),
2261 hash: "$argon2id$v=19$m=19456,t=2,p=1$c2FsdHNhbHQ$h4sh3dPa55w0rd".into(),
2263 role: "viewer".into(),
2264 expires_at: Some(RfcTimestamp::parse("2030-01-01T00:00:00Z").unwrap()),
2265 };
2266 let dbg = format!("{entry:?}");
2267
2268 assert!(dbg.contains("viewer-key"));
2270 assert!(dbg.contains("viewer"));
2271 assert!(dbg.contains("2030-01-01T00:00:00+00:00"));
2272
2273 assert!(
2275 !dbg.contains("$argon2id$"),
2276 "argon2 hash leaked into Debug output: {dbg}"
2277 );
2278 assert!(
2279 !dbg.contains("h4sh3dPa55w0rd"),
2280 "hash digest leaked into Debug output: {dbg}"
2281 );
2282 assert!(
2283 dbg.contains("<redacted>"),
2284 "redaction marker missing: {dbg}"
2285 );
2286 }
2287
2288 #[test]
2299 fn auth_failure_class_as_str_exact_strings() {
2300 assert_eq!(
2301 AuthFailureClass::MissingCredential.as_str(),
2302 "missing_credential"
2303 );
2304 assert_eq!(
2305 AuthFailureClass::InvalidCredential.as_str(),
2306 "invalid_credential"
2307 );
2308 assert_eq!(
2309 AuthFailureClass::ExpiredCredential.as_str(),
2310 "expired_credential"
2311 );
2312 assert_eq!(AuthFailureClass::RateLimited.as_str(), "rate_limited");
2313 assert_eq!(AuthFailureClass::PreAuthGate.as_str(), "pre_auth_gate");
2314 }
2315
2316 #[test]
2317 fn auth_failure_class_response_body_exact_strings() {
2318 assert_eq!(
2319 AuthFailureClass::MissingCredential.response_body(),
2320 "unauthorized: missing credential"
2321 );
2322 assert_eq!(
2323 AuthFailureClass::InvalidCredential.response_body(),
2324 "unauthorized: invalid credential"
2325 );
2326 assert_eq!(
2327 AuthFailureClass::ExpiredCredential.response_body(),
2328 "unauthorized: expired credential"
2329 );
2330 assert_eq!(
2331 AuthFailureClass::RateLimited.response_body(),
2332 "rate limited"
2333 );
2334 assert_eq!(
2335 AuthFailureClass::PreAuthGate.response_body(),
2336 "rate limited (pre-auth)"
2337 );
2338 }
2339
2340 #[test]
2341 fn auth_failure_class_bearer_error_exact_strings() {
2342 assert_eq!(
2343 AuthFailureClass::MissingCredential.bearer_error(),
2344 (
2345 "invalid_request",
2346 "missing bearer token or mTLS client certificate"
2347 )
2348 );
2349 assert_eq!(
2350 AuthFailureClass::InvalidCredential.bearer_error(),
2351 ("invalid_token", "token is invalid")
2352 );
2353 assert_eq!(
2354 AuthFailureClass::ExpiredCredential.bearer_error(),
2355 ("invalid_token", "token is expired")
2356 );
2357 assert_eq!(
2358 AuthFailureClass::RateLimited.bearer_error(),
2359 ("invalid_request", "too many failed authentication attempts")
2360 );
2361 assert_eq!(
2362 AuthFailureClass::PreAuthGate.bearer_error(),
2363 (
2364 "invalid_request",
2365 "too many unauthenticated requests from this source"
2366 )
2367 );
2368 }
2369
2370 #[test]
2379 fn auth_config_summary_bearer_true_when_keys_present() {
2380 let (_token, hash) = generate_api_key().unwrap();
2381 let cfg = AuthConfig::with_keys(vec![ApiKeyEntry::new("k", hash, "viewer")]);
2382 let s = cfg.summary();
2383 assert!(s.enabled, "summary.enabled must reflect AuthConfig.enabled");
2384 assert!(
2385 s.bearer,
2386 "summary.bearer must be true when api_keys is non-empty (kills `!` deletion at L615)"
2387 );
2388 assert!(!s.mtls, "summary.mtls must be false when mtls is None");
2389 assert!(!s.oauth, "summary.oauth must be false when oauth is None");
2390 assert_eq!(s.api_keys.len(), 1);
2391 assert_eq!(s.api_keys[0].name, "k");
2392 assert_eq!(s.api_keys[0].role, "viewer");
2393 }
2394
2395 #[test]
2396 fn auth_config_summary_bearer_false_when_no_keys() {
2397 let cfg = AuthConfig::with_keys(vec![]);
2398 let s = cfg.summary();
2399 assert!(
2400 !s.bearer,
2401 "summary.bearer must be false when api_keys is empty (kills `!` deletion at L615)"
2402 );
2403 assert!(s.api_keys.is_empty());
2404 }
2405
2406 #[test]
2407 fn seen_identity_set_first_then_repeat() {
2408 let set = SeenIdentitySet::new();
2409 assert!(set.insert_is_first("alice"), "first sighting is first");
2410 assert!(
2411 !set.insert_is_first("alice"),
2412 "second sighting is not first"
2413 );
2414 assert!(set.insert_is_first("bob"));
2415 assert_eq!(set.len(), 2);
2416 }
2417
2418 #[test]
2419 fn seen_identity_set_evicts_oldest_at_cap() {
2420 let set = SeenIdentitySet::with_cap(2);
2421 assert!(set.insert_is_first("a"));
2422 assert!(set.insert_is_first("b"));
2423 assert!(set.insert_is_first("c"));
2425 assert_eq!(set.len(), 2);
2426 assert!(set.insert_is_first("a"));
2430 assert_eq!(set.len(), 2);
2431 assert!(set.insert_is_first("b"));
2433 for i in 0..32 {
2435 set.insert_is_first(&format!("churn-{i}"));
2436 assert!(set.len() <= 2, "cap invariant must hold");
2437 }
2438 }
2439
2440 #[test]
2441 fn seen_identity_set_cap_zero_is_raised_to_one() {
2442 let set = SeenIdentitySet::with_cap(0);
2443 assert!(set.insert_is_first("only"));
2444 assert_eq!(set.len(), 1);
2445 assert!(set.insert_is_first("next"));
2447 assert_eq!(set.len(), 1);
2448 }
2449
2450 #[test]
2451 fn seen_identity_set_fifo_does_not_refresh_on_repeat_hit() {
2452 let set = SeenIdentitySet::with_cap(2);
2455 assert!(set.insert_is_first("a")); assert!(set.insert_is_first("b")); assert!(!set.insert_is_first("a"));
2461 assert!(set.insert_is_first("c"));
2464 assert!(set.insert_is_first("a"));
2466 let set = SeenIdentitySet::with_cap(2);
2472 assert!(set.insert_is_first("x")); assert!(set.insert_is_first("y")); assert!(!set.insert_is_first("x")); assert!(set.insert_is_first("z")); assert!(
2477 !set.insert_is_first("y"),
2478 "y must still be present (FIFO did not evict it)"
2479 );
2480 assert!(
2481 set.insert_is_first("x"),
2482 "x must have been evicted by FIFO (would NOT have been evicted under LRU)"
2483 );
2484 }
2485}