1use std::{
10 collections::HashSet,
11 net::{IpAddr, SocketAddr},
12 num::NonZeroU32,
13 path::PathBuf,
14 sync::{
15 Arc, LazyLock, Mutex,
16 atomic::{AtomicU64, Ordering},
17 },
18 time::Duration,
19};
20
21use arc_swap::ArcSwap;
22use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier, password_hash::SaltString};
23use axum::{
24 body::Body,
25 extract::ConnectInfo,
26 http::{Request, header},
27 middleware::Next,
28 response::{IntoResponse, Response},
29};
30use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
31use secrecy::SecretString;
32use serde::Deserialize;
33use x509_parser::prelude::*;
34
35use crate::{bounded_limiter::BoundedKeyedLimiter, error::McpxError};
36
/// Identity of a caller that passed authentication.
///
/// `auth_middleware` inserts it into the request extensions on success.
/// `Debug` is implemented by hand so that `raw_token` and `sub` are never
/// printed, only whether they are present.
#[derive(Clone)]
#[non_exhaustive]
pub struct AuthIdentity {
    /// Caller name: the API key entry's `name` for bearer API keys, or the
    /// certificate CN / first DNS SAN for mTLS. For OAuth JWTs it is chosen
    /// by the JWT validator (not visible in this file).
    pub name: String,
    /// Role granted to the caller (the API key's `role`, or the `default_role`
    /// passed to `extract_mtls_identity`).
    pub role: String,
    /// Mechanism that authenticated the caller.
    pub method: AuthMethod,
    /// The raw bearer token. This module sets it only for JWT-validated
    /// identities; API-key and mTLS identities are built with `None`.
    pub raw_token: Option<SecretString>,
    /// Subject identifier. Always `None` for API-key and mTLS identities built
    /// here; NOTE(review): presumably the JWT `sub` claim for OAuth identities —
    /// confirm in the oauth module.
    pub sub: Option<String>,
}
64
65impl std::fmt::Debug for AuthIdentity {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 f.debug_struct("AuthIdentity")
70 .field("name", &self.name)
71 .field("role", &self.role)
72 .field("method", &self.method)
73 .field(
74 "raw_token",
75 &if self.raw_token.is_some() {
76 "<redacted>"
77 } else {
78 "<none>"
79 },
80 )
81 .field(
82 "sub",
83 &if self.sub.is_some() {
84 "<redacted>"
85 } else {
86 "<none>"
87 },
88 )
89 .finish()
90 }
91}
92
/// Mechanism that authenticated an [`AuthIdentity`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum AuthMethod {
    /// API key sent as an `Authorization: Bearer` token and verified against
    /// the configured Argon2 hashes.
    BearerToken,
    /// Identity taken from a TLS client certificate.
    MtlsCertificate,
    /// Bearer token validated as a JWT through the JWKS cache (that path is
    /// only compiled with the `oauth` feature).
    OAuthJwt,
}
104
/// Classification of an authentication failure.
///
/// Used for the failure counters, the `failure_class` log field, and the
/// `WWW-Authenticate` challenge / response body of 401 replies.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AuthFailureClass {
    /// No `Authorization` header (and no mTLS identity) was present.
    MissingCredential,
    /// A credential was sent but was malformed or did not verify.
    InvalidCredential,
    /// A JWT was rejected as expired; only constructed with the `oauth` feature.
    #[cfg_attr(not(feature = "oauth"), allow(dead_code))]
    ExpiredCredential,
    /// Rejected by the per-IP failed-attempt limiter.
    RateLimited,
    /// Rejected by the per-IP pre-authentication limiter before any
    /// credential was examined.
    PreAuthGate,
}
117
118impl AuthFailureClass {
119 fn as_str(self) -> &'static str {
120 match self {
121 Self::MissingCredential => "missing_credential",
122 Self::InvalidCredential => "invalid_credential",
123 Self::ExpiredCredential => "expired_credential",
124 Self::RateLimited => "rate_limited",
125 Self::PreAuthGate => "pre_auth_gate",
126 }
127 }
128
129 fn bearer_error(self) -> (&'static str, &'static str) {
130 match self {
131 Self::MissingCredential => (
132 "invalid_request",
133 "missing bearer token or mTLS client certificate",
134 ),
135 Self::InvalidCredential => ("invalid_token", "token is invalid"),
136 Self::ExpiredCredential => ("invalid_token", "token is expired"),
137 Self::RateLimited => ("invalid_request", "too many failed authentication attempts"),
138 Self::PreAuthGate => (
139 "invalid_request",
140 "too many unauthenticated requests from this source",
141 ),
142 }
143 }
144
145 fn response_body(self) -> &'static str {
146 match self {
147 Self::MissingCredential => "unauthorized: missing credential",
148 Self::InvalidCredential => "unauthorized: invalid credential",
149 Self::ExpiredCredential => "unauthorized: expired credential",
150 Self::RateLimited => "rate limited",
151 Self::PreAuthGate => "rate limited (pre-auth)",
152 }
153 }
154}
155
/// Point-in-time copy of the authentication counters.
///
/// Each field is loaded independently with relaxed ordering (see
/// `AuthCounters::snapshot`). The fields are therefore not guaranteed to be
/// mutually consistent when taken under concurrent traffic.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
#[non_exhaustive]
pub struct AuthCountersSnapshot {
    /// Successful mTLS authentications.
    pub success_mtls: u64,
    /// Successful API-key (bearer) authentications.
    pub success_bearer: u64,
    /// Successful OAuth JWT authentications.
    pub success_oauth_jwt: u64,
    /// Requests rejected because no credential was sent.
    pub failure_missing_credential: u64,
    /// Requests rejected because the credential was malformed or did not verify.
    pub failure_invalid_credential: u64,
    /// Requests rejected because the JWT was expired.
    pub failure_expired_credential: u64,
    /// Failures answered with a rate-limit error by the failed-attempt limiter
    /// (counted here instead of under their underlying class).
    pub failure_rate_limited: u64,
    /// Requests rejected by the pre-authentication limiter.
    pub failure_pre_auth_gate: u64,
}
178
/// Lock-free atomic counters behind [`AuthCountersSnapshot`]; one counter per
/// snapshot field, all starting at zero (`Default`).
#[derive(Debug, Default)]
pub(crate) struct AuthCounters {
    success_mtls: AtomicU64,
    success_bearer: AtomicU64,
    success_oauth_jwt: AtomicU64,
    failure_missing_credential: AtomicU64,
    failure_invalid_credential: AtomicU64,
    failure_expired_credential: AtomicU64,
    failure_rate_limited: AtomicU64,
    failure_pre_auth_gate: AtomicU64,
}
191
192impl AuthCounters {
193 fn record_success(&self, method: AuthMethod) {
194 match method {
195 AuthMethod::MtlsCertificate => {
196 self.success_mtls.fetch_add(1, Ordering::Relaxed);
197 }
198 AuthMethod::BearerToken => {
199 self.success_bearer.fetch_add(1, Ordering::Relaxed);
200 }
201 AuthMethod::OAuthJwt => {
202 self.success_oauth_jwt.fetch_add(1, Ordering::Relaxed);
203 }
204 }
205 }
206
207 fn record_failure(&self, class: AuthFailureClass) {
208 match class {
209 AuthFailureClass::MissingCredential => {
210 self.failure_missing_credential
211 .fetch_add(1, Ordering::Relaxed);
212 }
213 AuthFailureClass::InvalidCredential => {
214 self.failure_invalid_credential
215 .fetch_add(1, Ordering::Relaxed);
216 }
217 AuthFailureClass::ExpiredCredential => {
218 self.failure_expired_credential
219 .fetch_add(1, Ordering::Relaxed);
220 }
221 AuthFailureClass::RateLimited => {
222 self.failure_rate_limited.fetch_add(1, Ordering::Relaxed);
223 }
224 AuthFailureClass::PreAuthGate => {
225 self.failure_pre_auth_gate.fetch_add(1, Ordering::Relaxed);
226 }
227 }
228 }
229
230 fn snapshot(&self) -> AuthCountersSnapshot {
231 AuthCountersSnapshot {
232 success_mtls: self.success_mtls.load(Ordering::Relaxed),
233 success_bearer: self.success_bearer.load(Ordering::Relaxed),
234 success_oauth_jwt: self.success_oauth_jwt.load(Ordering::Relaxed),
235 failure_missing_credential: self.failure_missing_credential.load(Ordering::Relaxed),
236 failure_invalid_credential: self.failure_invalid_credential.load(Ordering::Relaxed),
237 failure_expired_credential: self.failure_expired_credential.load(Ordering::Relaxed),
238 failure_rate_limited: self.failure_rate_limited.load(Ordering::Relaxed),
239 failure_pre_auth_gate: self.failure_pre_auth_gate.load(Ordering::Relaxed),
240 }
241 }
242}
243
244#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
256#[non_exhaustive]
257pub struct RfcTimestamp(chrono::DateTime<chrono::FixedOffset>);
258
259impl RfcTimestamp {
260 pub fn parse(s: &str) -> Result<Self, chrono::ParseError> {
268 chrono::DateTime::parse_from_rfc3339(s).map(Self)
269 }
270
271 #[must_use]
273 pub fn as_datetime(&self) -> &chrono::DateTime<chrono::FixedOffset> {
274 &self.0
275 }
276
277 #[must_use]
279 pub fn into_inner(self) -> chrono::DateTime<chrono::FixedOffset> {
280 self.0
281 }
282}
283
284impl std::fmt::Display for RfcTimestamp {
285 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286 write!(f, "{}", self.0.to_rfc3339())
288 }
289}
290
291impl std::fmt::Debug for RfcTimestamp {
292 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
293 write!(f, "{}", self.0.to_rfc3339())
298 }
299}
300
301impl<'de> Deserialize<'de> for RfcTimestamp {
302 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
303 where
304 D: serde::Deserializer<'de>,
305 {
306 let s = String::deserialize(deserializer)?;
310 Self::parse(&s).map_err(serde::de::Error::custom)
311 }
312}
313
314impl serde::Serialize for RfcTimestamp {
315 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
316 where
317 S: serde::Serializer,
318 {
319 serializer.serialize_str(&self.0.to_rfc3339())
320 }
321}
322
323impl From<chrono::DateTime<chrono::FixedOffset>> for RfcTimestamp {
324 fn from(value: chrono::DateTime<chrono::FixedOffset>) -> Self {
325 Self(value)
326 }
327}
328
329#[derive(Clone, Deserialize)]
336#[non_exhaustive]
337pub struct ApiKeyEntry {
338 pub name: String,
340 pub hash: String,
342 pub role: String,
344 pub expires_at: Option<RfcTimestamp>,
349}
350
351impl std::fmt::Debug for ApiKeyEntry {
352 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355 f.debug_struct("ApiKeyEntry")
356 .field("name", &self.name)
357 .field("hash", &"<redacted>")
358 .field("role", &self.role)
359 .field("expires_at", &self.expires_at)
360 .finish()
361 }
362}
363
364impl ApiKeyEntry {
365 #[must_use]
367 pub fn new(name: impl Into<String>, hash: impl Into<String>, role: impl Into<String>) -> Self {
368 Self {
369 name: name.into(),
370 hash: hash.into(),
371 role: role.into(),
372 expires_at: None,
373 }
374 }
375
376 #[must_use]
381 pub fn with_expiry(mut self, expires_at: RfcTimestamp) -> Self {
382 self.expires_at = Some(expires_at);
383 self
384 }
385
386 pub fn try_with_expiry(
394 mut self,
395 expires_at: impl AsRef<str>,
396 ) -> Result<Self, chrono::ParseError> {
397 self.expires_at = Some(RfcTimestamp::parse(expires_at.as_ref())?);
398 Ok(self)
399 }
400}
401
/// mTLS client-certificate authentication settings.
///
/// Defaults come from the `default_*` helpers in this module. Durations use
/// `humantime_serde` (e.g. `"30s"`, `"24h"`). The CRL settings are consumed
/// outside this file. Their descriptions below are based on the field names
/// and should be confirmed against the CRL implementation.
#[derive(Debug, Clone, Deserialize)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "mTLS CRL behavior is intentionally configured as independent booleans"
)]
#[non_exhaustive]
pub struct MtlsConfig {
    /// Path to the CA certificate(s) trusted for client certificates (required).
    pub ca_cert_path: PathBuf,
    /// Default `false`. NOTE(review): presumably makes a client certificate
    /// mandatory when `true` — confirm in the TLS acceptor.
    #[serde(default)]
    pub required: bool,
    /// Role for mTLS identities; presumably passed to `extract_mtls_identity`
    /// as `default_role`. Default `"viewer"`.
    #[serde(default = "default_mtls_role")]
    pub default_role: String,
    /// Whether CRL checking is enabled. Default `true`.
    #[serde(default = "default_true")]
    pub crl_enabled: bool,
    /// Optional CRL refresh interval; `None` when unset.
    #[serde(default, with = "humantime_serde::option")]
    pub crl_refresh_interval: Option<Duration>,
    /// CRL fetch timeout. Default 30 s.
    #[serde(default = "default_crl_fetch_timeout", with = "humantime_serde")]
    pub crl_fetch_timeout: Duration,
    /// Grace period for stale CRLs. Default 24 h.
    #[serde(default = "default_crl_stale_grace", with = "humantime_serde")]
    pub crl_stale_grace: Duration,
    /// Presumably rejects clients when no CRL is available. Default `false`.
    #[serde(default)]
    pub crl_deny_on_unavailable: bool,
    /// Presumably restricts revocation checks to the leaf certificate. Default `false`.
    #[serde(default)]
    pub crl_end_entity_only: bool,
    /// Presumably allows `http://` CRL distribution points. Default `true`.
    #[serde(default = "default_true")]
    pub crl_allow_http: bool,
    /// Presumably enforces CRL `nextUpdate` expiry. Default `true`.
    #[serde(default = "default_true")]
    pub crl_enforce_expiration: bool,
    /// Cap on concurrent CRL fetches. Default 4.
    #[serde(default = "default_crl_max_concurrent_fetches")]
    pub crl_max_concurrent_fetches: usize,
    /// Cap on a CRL response size in bytes. Default 5 MiB.
    #[serde(default = "default_crl_max_response_bytes")]
    pub crl_max_response_bytes: u64,
    /// Rate of new CRL URL discoveries per minute. Default 60.
    #[serde(default = "default_crl_discovery_rate_per_min")]
    pub crl_discovery_rate_per_min: u32,
    /// Cap on per-host fetch semaphores. Default 1024.
    #[serde(default = "default_crl_max_host_semaphores")]
    pub crl_max_host_semaphores: usize,
    /// Cap on remembered CRL URLs. Default 4096.
    #[serde(default = "default_crl_max_seen_urls")]
    pub crl_max_seen_urls: usize,
    /// Cap on cached CRLs. Default 1024.
    #[serde(default = "default_crl_max_cache_entries")]
    pub crl_max_cache_entries: usize,
}
499
/// Serde default for `MtlsConfig::default_role`.
fn default_mtls_role() -> String {
    String::from("viewer")
}

/// Serde default for boolean options that are on unless disabled.
const fn default_true() -> bool {
    true
}

/// Serde default for `crl_fetch_timeout`: 30 seconds.
const fn default_crl_fetch_timeout() -> Duration {
    Duration::from_secs(30)
}

/// Serde default for `crl_stale_grace`: 24 hours.
const fn default_crl_stale_grace() -> Duration {
    Duration::from_secs(24 * 60 * 60)
}

/// Serde default for `crl_max_concurrent_fetches`.
const fn default_crl_max_concurrent_fetches() -> usize {
    4
}

/// Serde default for `crl_max_response_bytes`: 5 MiB.
const fn default_crl_max_response_bytes() -> u64 {
    5 << 20
}

/// Serde default for `crl_discovery_rate_per_min`.
const fn default_crl_discovery_rate_per_min() -> u32 {
    60
}

/// Serde default for `crl_max_host_semaphores`.
const fn default_crl_max_host_semaphores() -> usize {
    1 << 10
}

/// Serde default for `crl_max_seen_urls`.
const fn default_crl_max_seen_urls() -> usize {
    1 << 12
}

/// Serde default for `crl_max_cache_entries`.
const fn default_crl_max_cache_entries() -> usize {
    1 << 10
}
539
/// Per-IP authentication rate-limit settings.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct RateLimitConfig {
    /// Failed-attempt quota per IP per minute, enforced by the limiter from
    /// `build_rate_limiter` (checked only after a failed attempt). Default 30;
    /// a value of 0 falls back to 30.
    #[serde(default = "default_max_attempts")]
    pub max_attempts_per_minute: u32,
    /// Pre-authentication quota per IP per minute, checked before credentials
    /// are verified (see `build_pre_auth_limiter`). When `None`, it is
    /// `max_attempts_per_minute * 10` (saturating). A resolved value of 0
    /// falls back to 300.
    #[serde(default)]
    pub pre_auth_max_per_minute: Option<u32>,
    /// Passed to `BoundedKeyedLimiter::new`; presumably caps how many distinct
    /// IPs are tracked — confirm in `bounded_limiter`. Default 10 000.
    #[serde(default = "default_max_tracked_keys")]
    pub max_tracked_keys: usize,
    /// Passed to `BoundedKeyedLimiter::new`; presumably how long an idle IP
    /// entry is kept. Default 15 minutes.
    #[serde(default = "default_idle_eviction", with = "humantime_serde")]
    pub idle_eviction: Duration,
}
581
582impl Default for RateLimitConfig {
583 fn default() -> Self {
584 Self {
585 max_attempts_per_minute: default_max_attempts(),
586 pre_auth_max_per_minute: None,
587 max_tracked_keys: default_max_tracked_keys(),
588 idle_eviction: default_idle_eviction(),
589 }
590 }
591}
592
593impl RateLimitConfig {
594 #[must_use]
598 pub fn new(max_attempts_per_minute: u32) -> Self {
599 Self {
600 max_attempts_per_minute,
601 ..Self::default()
602 }
603 }
604
605 #[must_use]
608 pub fn with_pre_auth_max_per_minute(mut self, quota: u32) -> Self {
609 self.pre_auth_max_per_minute = Some(quota);
610 self
611 }
612
613 #[must_use]
615 pub fn with_max_tracked_keys(mut self, max: usize) -> Self {
616 self.max_tracked_keys = max;
617 self
618 }
619
620 #[must_use]
622 pub fn with_idle_eviction(mut self, idle: Duration) -> Self {
623 self.idle_eviction = idle;
624 self
625 }
626}
627
// `const fn`, consistent with the mTLS `default_*` helpers above, so these
// values are also usable in constant contexts.

/// Serde/`Default` value for `max_attempts_per_minute`: 30 failures per minute.
const fn default_max_attempts() -> u32 {
    30
}

/// Serde/`Default` value for `max_tracked_keys`: 10 000.
const fn default_max_tracked_keys() -> usize {
    10_000
}

/// Serde/`Default` value for `idle_eviction`: 15 minutes.
const fn default_idle_eviction() -> Duration {
    Duration::from_secs(15 * 60)
}
639
/// Top-level authentication configuration.
#[derive(Debug, Clone, Default, Deserialize)]
#[non_exhaustive]
pub struct AuthConfig {
    /// Default `false`. NOTE(review): presumably controls whether the auth
    /// middleware is installed — confirm at the call site.
    #[serde(default)]
    pub enabled: bool,
    /// Accepted API keys; empty by default.
    #[serde(default)]
    pub api_keys: Vec<ApiKeyEntry>,
    /// mTLS settings; `None` when the section is absent.
    pub mtls: Option<MtlsConfig>,
    /// Rate-limit settings; `None` when the section is absent.
    pub rate_limit: Option<RateLimitConfig>,
    /// OAuth/JWT settings (only with the `oauth` feature).
    #[cfg(feature = "oauth")]
    pub oauth: Option<crate::oauth::OAuthConfig>,
}
658
659impl AuthConfig {
660 #[must_use]
662 pub fn with_keys(keys: Vec<ApiKeyEntry>) -> Self {
663 Self {
664 enabled: true,
665 api_keys: keys,
666 mtls: None,
667 rate_limit: None,
668 #[cfg(feature = "oauth")]
669 oauth: None,
670 }
671 }
672
673 #[must_use]
675 pub fn with_rate_limit(mut self, rate_limit: RateLimitConfig) -> Self {
676 self.rate_limit = Some(rate_limit);
677 self
678 }
679}
680
/// Non-secret view of an [`ApiKeyEntry`]: everything except the hash.
#[derive(Debug, Clone, serde::Serialize)]
#[non_exhaustive]
pub struct ApiKeySummary {
    /// Key name.
    pub name: String,
    /// Role granted by the key.
    pub role: String,
    /// Expiry, serialized as an RFC 3339 string when present.
    pub expires_at: Option<RfcTimestamp>,
}
695
/// Secret-free summary of an [`AuthConfig`], built by `AuthConfig::summary`.
#[derive(Debug, Clone, serde::Serialize)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "this is a flat summary of independent auth-method booleans"
)]
#[non_exhaustive]
pub struct AuthConfigSummary {
    /// Copy of `AuthConfig::enabled`.
    pub enabled: bool,
    /// `true` when at least one API key is configured.
    pub bearer: bool,
    /// `true` when an mTLS section is configured.
    pub mtls: bool,
    /// `true` when OAuth is configured; always `false` without the `oauth` feature.
    pub oauth: bool,
    /// Summaries of the configured API keys, in configuration order.
    pub api_keys: Vec<ApiKeySummary>,
}
715
716impl AuthConfig {
717 #[must_use]
719 pub fn summary(&self) -> AuthConfigSummary {
720 AuthConfigSummary {
721 enabled: self.enabled,
722 bearer: !self.api_keys.is_empty(),
723 mtls: self.mtls.is_some(),
724 #[cfg(feature = "oauth")]
725 oauth: self.oauth.is_some(),
726 #[cfg(not(feature = "oauth"))]
727 oauth: false,
728 api_keys: self
729 .api_keys
730 .iter()
731 .map(|k| ApiKeySummary {
732 name: k.name.clone(),
733 role: k.role.clone(),
734 expires_at: k.expires_at,
735 })
736 .collect(),
737 }
738 }
739}
740
/// Per-IP rate limiter type used for both the failed-attempt limiter and the
/// pre-authentication limiter.
pub(crate) type KeyedLimiter = BoundedKeyedLimiter<IpAddr>;
744
/// Connection info for TLS connections: the peer address plus the identity
/// extracted from the client certificate, if any.
///
/// `auth_middleware` reads it from the request extensions as
/// `ConnectInfo<TlsConnInfo>`. When `identity` is `Some`, the request is
/// authenticated without any further checks.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub(crate) struct TlsConnInfo {
    /// Remote peer address; used for rate limiting when no plain
    /// `ConnectInfo<SocketAddr>` is present.
    pub addr: SocketAddr,
    /// Identity from the client certificate, if one was presented and accepted.
    pub identity: Option<AuthIdentity>,
}

impl TlsConnInfo {
    /// Bundles a peer address with an optional certificate identity.
    #[must_use]
    pub(crate) const fn new(addr: SocketAddr, identity: Option<AuthIdentity>) -> Self {
        Self { addr, identity }
    }
}
771
/// Shared runtime state for `auth_middleware`.
#[allow(
    missing_debug_implementations,
    reason = "contains governor RateLimiter and JwksCache without Debug impls"
)]
#[non_exhaustive]
pub(crate) struct AuthState {
    /// Current API keys. `ArcSwap` lets `reload_keys` replace the whole set
    /// while in-flight verifications keep using the set they loaded.
    pub api_keys: ArcSwap<Vec<ApiKeyEntry>>,
    /// Per-IP limiter consulted only after a failed attempt; `None` disables it.
    pub rate_limiter: Option<Arc<KeyedLimiter>>,
    /// Per-IP limiter consulted before credential verification for every
    /// non-mTLS request; `None` disables it.
    pub pre_auth_limiter: Option<Arc<KeyedLimiter>>,
    /// When set, JWT-looking bearer tokens are validated against it, and 401
    /// challenges advertise the protected-resource metadata URL.
    #[cfg(feature = "oauth")]
    pub jwks_cache: Option<Arc<crate::oauth::JwksCache>>,
    /// Names already logged at `info` level; repeat authentications log at
    /// `debug`. Entries are never removed in this file.
    pub seen_identities: Mutex<HashSet<String>>,
    /// Success/failure counters.
    pub counters: AuthCounters,
}
798
799impl AuthState {
800 pub(crate) fn reload_keys(&self, keys: Vec<ApiKeyEntry>) {
806 let count = keys.len();
807 self.api_keys.store(Arc::new(keys));
808 tracing::info!(keys = count, "API keys reloaded");
809 }
810
811 #[must_use]
813 pub(crate) fn counters_snapshot(&self) -> AuthCountersSnapshot {
814 self.counters.snapshot()
815 }
816
817 #[must_use]
819 pub(crate) fn api_key_summaries(&self) -> Vec<ApiKeySummary> {
820 self.api_keys
821 .load()
822 .iter()
823 .map(|k| ApiKeySummary {
824 name: k.name.clone(),
825 role: k.role.clone(),
826 expires_at: k.expires_at,
827 })
828 .collect()
829 }
830
831 fn log_auth(&self, id: &AuthIdentity, method: &str) {
833 self.counters.record_success(id.method);
834 let first = self
835 .seen_identities
836 .lock()
837 .unwrap_or_else(std::sync::PoisonError::into_inner)
838 .insert(id.name.clone());
839 if first {
840 tracing::info!(name = %id.name, role = %id.role, "{method} authenticated");
841 } else {
842 tracing::debug!(name = %id.name, role = %id.role, "{method} authenticated");
843 }
844 }
845}
846
847const DEFAULT_AUTH_RATE: NonZeroU32 = NonZeroU32::new(30).unwrap();
850
851#[must_use]
853pub(crate) fn build_rate_limiter(config: &RateLimitConfig) -> Arc<KeyedLimiter> {
854 let quota = governor::Quota::per_minute(
855 NonZeroU32::new(config.max_attempts_per_minute).unwrap_or(DEFAULT_AUTH_RATE),
856 );
857 Arc::new(BoundedKeyedLimiter::new(
858 quota,
859 config.max_tracked_keys,
860 config.idle_eviction,
861 ))
862}
863
864#[must_use]
871pub(crate) fn build_pre_auth_limiter(config: &RateLimitConfig) -> Arc<KeyedLimiter> {
872 let resolved = config.pre_auth_max_per_minute.unwrap_or_else(|| {
873 config
874 .max_attempts_per_minute
875 .saturating_mul(PRE_AUTH_DEFAULT_MULTIPLIER)
876 });
877 let quota =
878 governor::Quota::per_minute(NonZeroU32::new(resolved).unwrap_or(DEFAULT_PRE_AUTH_RATE));
879 Arc::new(BoundedKeyedLimiter::new(
880 quota,
881 config.max_tracked_keys,
882 config.idle_eviction,
883 ))
884}
885
886const PRE_AUTH_DEFAULT_MULTIPLIER: u32 = 10;
889
890const DEFAULT_PRE_AUTH_RATE: NonZeroU32 = NonZeroU32::new(300).unwrap();
894
895#[must_use]
900pub fn extract_mtls_identity(cert_der: &[u8], default_role: &str) -> Option<AuthIdentity> {
901 let (_, cert) = X509Certificate::from_der(cert_der).ok()?;
902
903 let cn = cert
905 .subject()
906 .iter_common_name()
907 .next()
908 .and_then(|attr| attr.as_str().ok())
909 .map(String::from);
910
911 let name = cn.or_else(|| {
913 cert.subject_alternative_name()
914 .ok()
915 .flatten()
916 .and_then(|san| {
917 #[allow(clippy::wildcard_enum_match_arm)]
918 san.value.general_names.iter().find_map(|gn| match gn {
919 GeneralName::DNSName(dns) => Some((*dns).to_owned()),
920 _ => None,
921 })
922 })
923 })?;
924
925 if !name
927 .chars()
928 .all(|c| c.is_alphanumeric() || matches!(c, '-' | '.' | '_' | '@'))
929 {
930 tracing::warn!(cn = %name, "mTLS identity rejected: invalid characters in CN/SAN");
931 return None;
932 }
933
934 Some(AuthIdentity {
935 name,
936 role: default_role.to_owned(),
937 method: AuthMethod::MtlsCertificate,
938 raw_token: None,
939 sub: None,
940 })
941}
942
/// Extracts the token from an `Authorization` header value that uses the
/// `Bearer` scheme.
///
/// The scheme is matched case-insensitively and must be followed by at least
/// one space. Any extra leading spaces before the token are skipped. Returns
/// `None` for other schemes, a missing separator, or an empty token.
fn extract_bearer(value: &str) -> Option<&str> {
    let (scheme, rest) = value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    Some(rest.trim_start_matches(' ')).filter(|token| !token.is_empty())
}
966
/// Verifies `token` against `keys` and returns the identity of the first
/// unexpired entry whose hash matches.
///
/// Each entry costs exactly one Argon2 verification, wherever (or whether) a
/// match occurs, so response time does not reveal which slot matched.
/// Three kinds of entry are verified against the fixed `DUMMY_PHC_HASH`
/// instead of their own hash:
/// - expired entries (`expires_at` strictly before now);
/// - entries whose `hash` is not a valid PHC string;
/// - every entry after the first match.
///
/// The first match can only be recorded while `any_match == 0`, and at that
/// point an eligible entry is checked against its real hash. A token equal to
/// the dummy plaintext therefore cannot authenticate.
///
/// NOTE(review): the dummy hash uses `Argon2::default()` parameters. Entries
/// hashed with different cost parameters take a different time to verify, and
/// an empty `keys` slice returns immediately with no Argon2 work. Confirm this
/// is acceptable.
///
/// Returns `None` when no entry matches. The returned identity has
/// `method = BearerToken` and no `raw_token`/`sub`.
#[must_use]
pub fn verify_bearer_token(token: &str, keys: &[ApiKeyEntry]) -> Option<AuthIdentity> {
    use subtle::ConstantTimeEq as _;

    let now = chrono::Utc::now();
    // Cannot fail: DUMMY_PHC_HASH is produced by `hash_password(..).to_string()`.
    let dummy_hash = PasswordHash::new(&DUMMY_PHC_HASH)
        .expect("DUMMY_PHC_HASH is a valid Argon2id PHC string by construction");

    // `usize::MAX` is only a placeholder; it is read only when `any_match != 0`,
    // which implies it was overwritten.
    let mut matched_index: usize = usize::MAX;
    // 0/1 flag kept as an integer so it can be combined with `&`/`|` below.
    let mut any_match: u8 = 0;

    for (idx, key) in keys.iter().enumerate() {
        let expired = key.expires_at.is_some_and(|exp| exp.as_datetime() < &now);

        // Use the entry's own hash only while it is eligible to be the first
        // match; otherwise burn the same work on the dummy hash.
        let real_hash = PasswordHash::new(&key.hash);
        let verify_against = match (&real_hash, expired, any_match) {
            (Ok(h), false, 0) => h,
            _ => &dummy_hash,
        };

        let slot_ok = u8::from(
            Argon2::default()
                .verify_password(token.as_bytes(), verify_against)
                .is_ok(),
        );

        // A slot counts only if it verified, is unexpired, and has a real hash.
        let real_match = slot_ok & u8::from(!expired) & u8::from(real_hash.is_ok());
        // Keep first-match semantics: later matches never move the index.
        let first_real_match = real_match & (1 - any_match);
        if first_real_match.ct_eq(&1).into() {
            matched_index = idx;
        }
        any_match |= real_match;
    }

    if any_match == 0 {
        return None;
    }
    let key = keys.get(matched_index)?;
    Some(AuthIdentity {
        name: key.name.clone(),
        role: key.role.clone(),
        method: AuthMethod::BearerToken,
        raw_token: None,
        sub: None,
    })
}
1041
1042static DUMMY_PHC_HASH: LazyLock<String> = LazyLock::new(|| {
1055 let salt = SaltString::from_b64("AAAAAAAAAAAAAAAAAAAAAA")
1057 .expect("fixed 16-byte base64 salt is well-formed");
1058 Argon2::default()
1059 .hash_password(b"rmcp-server-kit-dummy", &salt)
1060 .expect("Argon2 default params hash a fixed plaintext")
1061 .to_string()
1062});
1063
1064pub fn generate_api_key() -> Result<(String, String), McpxError> {
1074 let mut token_bytes = [0u8; 32];
1075 rand::fill(&mut token_bytes);
1076 let token = URL_SAFE_NO_PAD.encode(token_bytes);
1077
1078 let mut salt_bytes = [0u8; 16];
1080 rand::fill(&mut salt_bytes);
1081 let salt = SaltString::encode_b64(&salt_bytes)
1082 .map_err(|e| McpxError::Auth(format!("salt encoding failed: {e}")))?;
1083 let hash = Argon2::default()
1084 .hash_password(token.as_bytes(), &salt)
1085 .map_err(|e| McpxError::Auth(format!("argon2id hashing failed: {e}")))?
1086 .to_string();
1087
1088 Ok((token, hash))
1089}
1090
1091fn build_www_authenticate_value(
1092 advertise_resource_metadata: bool,
1093 failure: AuthFailureClass,
1094) -> String {
1095 let (error, error_description) = failure.bearer_error();
1096 if advertise_resource_metadata {
1097 return format!(
1098 "Bearer resource_metadata=\"/.well-known/oauth-protected-resource\", error=\"{error}\", error_description=\"{error_description}\""
1099 );
1100 }
1101 format!("Bearer error=\"{error}\", error_description=\"{error_description}\"")
1102}
1103
1104fn auth_method_label(method: AuthMethod) -> &'static str {
1105 match method {
1106 AuthMethod::MtlsCertificate => "mTLS",
1107 AuthMethod::BearerToken => "bearer token",
1108 AuthMethod::OAuthJwt => "OAuth JWT",
1109 }
1110}
1111
1112#[cfg_attr(not(feature = "oauth"), allow(unused_variables))]
1113fn unauthorized_response(state: &AuthState, failure_class: AuthFailureClass) -> Response {
1114 #[cfg(feature = "oauth")]
1115 let advertise_resource_metadata = state.jwks_cache.is_some();
1116 #[cfg(not(feature = "oauth"))]
1117 let advertise_resource_metadata = false;
1118
1119 let challenge = build_www_authenticate_value(advertise_resource_metadata, failure_class);
1120 (
1121 axum::http::StatusCode::UNAUTHORIZED,
1122 [(header::WWW_AUTHENTICATE, challenge)],
1123 failure_class.response_body(),
1124 )
1125 .into_response()
1126}
1127
1128async fn authenticate_bearer_identity(
1129 state: &AuthState,
1130 token: &str,
1131) -> Result<AuthIdentity, AuthFailureClass> {
1132 let mut failure_class = AuthFailureClass::MissingCredential;
1133
1134 #[cfg(feature = "oauth")]
1135 if let Some(ref cache) = state.jwks_cache
1136 && crate::oauth::looks_like_jwt(token)
1137 {
1138 match cache.validate_token_with_reason(token).await {
1139 Ok(mut id) => {
1140 id.raw_token = Some(SecretString::from(token.to_owned()));
1141 return Ok(id);
1142 }
1143 Err(crate::oauth::JwtValidationFailure::Expired) => {
1144 failure_class = AuthFailureClass::ExpiredCredential;
1145 }
1146 Err(crate::oauth::JwtValidationFailure::Invalid) => {
1147 failure_class = AuthFailureClass::InvalidCredential;
1148 }
1149 }
1150 }
1151
1152 let token = token.to_owned();
1153 let keys = state.api_keys.load_full(); let identity = tokio::task::spawn_blocking(move || verify_bearer_token(&token, &keys))
1157 .await
1158 .ok()
1159 .flatten();
1160
1161 if let Some(id) = identity {
1162 return Ok(id);
1163 }
1164
1165 if failure_class == AuthFailureClass::MissingCredential {
1166 failure_class = AuthFailureClass::InvalidCredential;
1167 }
1168
1169 Err(failure_class)
1170}
1171
1172fn pre_auth_gate(state: &AuthState, peer_addr: Option<SocketAddr>) -> Option<Response> {
1183 let limiter = state.pre_auth_limiter.as_ref()?;
1184 let addr = peer_addr?;
1185 if limiter.check_key(&addr.ip()).is_ok() {
1186 return None;
1187 }
1188 state.counters.record_failure(AuthFailureClass::PreAuthGate);
1189 tracing::warn!(
1190 ip = %addr.ip(),
1191 "auth rate limited by pre-auth gate (request rejected before credential verification)"
1192 );
1193 Some(
1194 McpxError::RateLimited("too many unauthenticated requests from this source".into())
1195 .into_response(),
1196 )
1197}
1198
/// Axum middleware that authenticates every request before passing it on.
///
/// The checks run in this order:
/// 1. If the TLS layer attached an mTLS identity (`ConnectInfo<TlsConnInfo>`),
///    the request is accepted immediately; no limiter is consulted.
/// 2. Otherwise the pre-authentication limiter runs (`pre_auth_gate`).
/// 3. An `Authorization: Bearer` token is verified
///    (`authenticate_bearer_identity`). On success the [`AuthIdentity`] is
///    inserted into the request extensions and the next handler runs.
/// 4. On failure, the failed-attempt limiter is consulted. If it rejects, a
///    rate-limit error is returned and counted as `RateLimited`. Otherwise a
///    401 with a `WWW-Authenticate` challenge is returned and counted under
///    its failure class.
///
/// Both limiters are skipped when the peer address is unknown.
pub(crate) async fn auth_middleware(
    state: Arc<AuthState>,
    req: Request<Body>,
    next: Next,
) -> Response {
    let tls_info = req.extensions().get::<ConnectInfo<TlsConnInfo>>().cloned();
    // Prefer a plain `ConnectInfo<SocketAddr>`; fall back to the TLS peer address.
    let peer_addr = req
        .extensions()
        .get::<ConnectInfo<SocketAddr>>()
        .map(|ci| ci.0)
        .or_else(|| tls_info.as_ref().map(|ci| ci.0.addr));

    // mTLS identity established during the handshake: accept as-is.
    if let Some(id) = tls_info.and_then(|ci| ci.0.identity) {
        state.log_auth(&id, "mTLS");
        let mut req = req;
        req.extensions_mut().insert(id);
        return next.run(req).await;
    }

    // Per-IP gate applied before any (expensive) credential verification.
    if let Some(blocked) = pre_auth_gate(&state, peer_addr) {
        return blocked;
    }

    // A present but non-UTF-8 or non-Bearer header counts as invalid, not missing.
    let failure_class = if let Some(value) = req.headers().get(header::AUTHORIZATION) {
        match value.to_str().ok().and_then(extract_bearer) {
            Some(token) => match authenticate_bearer_identity(&state, token).await {
                Ok(id) => {
                    state.log_auth(&id, auth_method_label(id.method));
                    let mut req = req;
                    req.extensions_mut().insert(id);
                    return next.run(req).await;
                }
                Err(class) => class,
            },
            None => AuthFailureClass::InvalidCredential,
        }
    } else {
        AuthFailureClass::MissingCredential
    };

    tracing::warn!(failure_class = %failure_class.as_str(), "auth failed");

    // The failed-attempt limiter only consumes quota on failures.
    if let (Some(limiter), Some(addr)) = (&state.rate_limiter, peer_addr)
        && limiter.check_key(&addr.ip()).is_err()
    {
        state.counters.record_failure(AuthFailureClass::RateLimited);
        tracing::warn!(ip = %addr.ip(), "auth rate limited after repeated failures");
        return McpxError::RateLimited("too many failed authentication attempts".into())
            .into_response();
    }

    state.counters.record_failure(failure_class);
    unauthorized_response(&state, failure_class)
}
1276
1277#[cfg(test)]
1278mod tests {
1279 use super::*;
1280
1281 #[test]
1282 fn generate_and_verify_api_key() {
1283 let (token, hash) = generate_api_key().unwrap();
1284
1285 assert_eq!(token.len(), 43);
1287
1288 assert!(hash.starts_with("$argon2id$"));
1290
1291 let keys = vec![ApiKeyEntry {
1293 name: "test".into(),
1294 hash,
1295 role: "viewer".into(),
1296 expires_at: None,
1297 }];
1298 let id = verify_bearer_token(&token, &keys);
1299 assert!(id.is_some());
1300 let id = id.unwrap();
1301 assert_eq!(id.name, "test");
1302 assert_eq!(id.role, "viewer");
1303 assert_eq!(id.method, AuthMethod::BearerToken);
1304 }
1305
1306 #[test]
1307 fn wrong_token_rejected() {
1308 let (_token, hash) = generate_api_key().unwrap();
1309 let keys = vec![ApiKeyEntry {
1310 name: "test".into(),
1311 hash,
1312 role: "viewer".into(),
1313 expires_at: None,
1314 }];
1315 assert!(verify_bearer_token("wrong-token", &keys).is_none());
1316 }
1317
1318 #[test]
1319 fn expired_key_rejected() {
1320 let (token, hash) = generate_api_key().unwrap();
1321 let keys = vec![ApiKeyEntry {
1322 name: "test".into(),
1323 hash,
1324 role: "viewer".into(),
1325 expires_at: Some(RfcTimestamp::parse("2020-01-01T00:00:00Z").unwrap()),
1326 }];
1327 assert!(verify_bearer_token(&token, &keys).is_none());
1328 }
1329
1330 #[test]
1331 fn match_in_last_slot_still_authenticates() {
1332 let (token, hash) = generate_api_key().unwrap();
1333 let (_other_token, other_hash) = generate_api_key().unwrap();
1334 let keys = vec![
1335 ApiKeyEntry {
1336 name: "first".into(),
1337 hash: other_hash.clone(),
1338 role: "viewer".into(),
1339 expires_at: None,
1340 },
1341 ApiKeyEntry {
1342 name: "second".into(),
1343 hash: other_hash,
1344 role: "viewer".into(),
1345 expires_at: None,
1346 },
1347 ApiKeyEntry {
1348 name: "match".into(),
1349 hash,
1350 role: "ops".into(),
1351 expires_at: None,
1352 },
1353 ];
1354 let id = verify_bearer_token(&token, &keys).expect("last-slot match must authenticate");
1355 assert_eq!(id.name, "match");
1356 assert_eq!(id.role, "ops");
1357 }
1358
1359 #[test]
1360 fn expired_slot_before_valid_match_does_not_short_circuit() {
1361 let (token, hash) = generate_api_key().unwrap();
1362 let (_, other_hash) = generate_api_key().unwrap();
1363 let keys = vec![
1364 ApiKeyEntry {
1365 name: "expired".into(),
1366 hash: other_hash,
1367 role: "viewer".into(),
1368 expires_at: Some(RfcTimestamp::parse("2020-01-01T00:00:00Z").unwrap()),
1369 },
1370 ApiKeyEntry {
1371 name: "valid".into(),
1372 hash,
1373 role: "ops".into(),
1374 expires_at: None,
1375 },
1376 ];
1377 let id = verify_bearer_token(&token, &keys)
1378 .expect("valid slot following an expired slot must authenticate");
1379 assert_eq!(id.name, "valid");
1380 }
1381
1382 #[test]
1383 fn malformed_hash_slot_does_not_short_circuit() {
1384 let (token, hash) = generate_api_key().unwrap();
1385 let keys = vec![
1386 ApiKeyEntry {
1387 name: "broken".into(),
1388 hash: "this-is-not-a-phc-string".into(),
1389 role: "viewer".into(),
1390 expires_at: None,
1391 },
1392 ApiKeyEntry {
1393 name: "valid".into(),
1394 hash,
1395 role: "ops".into(),
1396 expires_at: None,
1397 },
1398 ];
1399 let id = verify_bearer_token(&token, &keys)
1400 .expect("valid slot following a malformed-hash slot must authenticate");
1401 assert_eq!(id.name, "valid");
1402 }
1403
1404 #[test]
1415 fn rfc_timestamp_parse_rejects_malformed() {
1416 for bad in [
1417 "not-a-date",
1418 "",
1419 "2025-13-01T00:00:00Z", "2025-01-32T00:00:00Z", "2025-01-01T00:00:00", "01/01/2025", "2025-01-01T25:00:00Z", ] {
1425 assert!(
1426 RfcTimestamp::parse(bad).is_err(),
1427 "RfcTimestamp::parse must reject {bad:?}"
1428 );
1429 }
1430 }
1431
1432 #[test]
1433 fn rfc_timestamp_parse_accepts_valid() {
1434 for good in [
1435 "2025-01-01T00:00:00Z",
1436 "2025-01-01T00:00:00+00:00",
1437 "2025-12-31T23:59:59-08:00",
1438 "2099-01-01T00:00:00.123456789Z",
1439 ] {
1440 assert!(
1441 RfcTimestamp::parse(good).is_ok(),
1442 "RfcTimestamp::parse must accept {good:?}"
1443 );
1444 }
1445 }
1446
1447 #[test]
1448 fn api_key_entry_deserialize_rejects_malformed_expires_at() {
1449 let toml = r#"
1454 name = "bad-key"
1455 hash = "$argon2id$v=19$m=19456,t=2,p=1$c2FsdA$h4sh"
1456 role = "viewer"
1457 expires_at = "not-a-date"
1458 "#;
1459 let result: Result<ApiKeyEntry, _> = toml::from_str(toml);
1460 assert!(
1461 result.is_err(),
1462 "deserialization must reject malformed expires_at"
1463 );
1464 }
1465
1466 #[test]
1467 fn api_key_entry_deserialize_accepts_valid_expires_at() {
1468 let toml = r#"
1469 name = "good-key"
1470 hash = "$argon2id$v=19$m=19456,t=2,p=1$c2FsdA$h4sh"
1471 role = "viewer"
1472 expires_at = "2099-01-01T00:00:00Z"
1473 "#;
1474 let entry: ApiKeyEntry = toml::from_str(toml).expect("valid RFC 3339 must deserialize");
1475 assert!(entry.expires_at.is_some());
1476 }
1477
1478 #[test]
1479 fn api_key_entry_deserialize_accepts_missing_expires_at() {
1480 let toml = r#"
1483 name = "eternal-key"
1484 hash = "$argon2id$v=19$m=19456,t=2,p=1$c2FsdA$h4sh"
1485 role = "viewer"
1486 "#;
1487 let entry: ApiKeyEntry = toml::from_str(toml).expect("missing expires_at must deserialize");
1488 assert!(entry.expires_at.is_none());
1489 }
1490
1491 #[test]
1492 fn try_with_expiry_rejects_malformed() {
1493 let entry = ApiKeyEntry::new("k", "hash", "viewer");
1494 assert!(entry.try_with_expiry("not-a-date").is_err());
1495 }
1496
1497 #[test]
1498 fn try_with_expiry_accepts_valid() {
1499 let entry = ApiKeyEntry::new("k", "hash", "viewer")
1500 .try_with_expiry("2099-01-01T00:00:00Z")
1501 .expect("valid RFC 3339 must be accepted");
1502 assert!(entry.expires_at.is_some());
1503 }
1504
1505 #[test]
1506 fn api_key_summary_serializes_expires_at_as_rfc3339() {
1507 let summary = ApiKeySummary {
1512 name: "k".into(),
1513 role: "viewer".into(),
1514 expires_at: Some(RfcTimestamp::parse("2030-01-01T00:00:00Z").unwrap()),
1515 };
1516 let json = serde_json::to_string(&summary).unwrap();
1517 assert!(
1518 json.contains(r#""expires_at":"2030-01-01T00:00:00+00:00""#),
1519 "wire format regressed: {json}"
1520 );
1521 }
1522
1523 #[test]
1524 fn future_expiry_accepted() {
1525 let (token, hash) = generate_api_key().unwrap();
1526 let keys = vec![ApiKeyEntry {
1527 name: "test".into(),
1528 hash,
1529 role: "viewer".into(),
1530 expires_at: Some(RfcTimestamp::parse("2099-01-01T00:00:00Z").unwrap()),
1531 }];
1532 assert!(verify_bearer_token(&token, &keys).is_some());
1533 }
1534
1535 #[test]
1536 fn multiple_keys_first_match_wins() {
1537 let (token, hash) = generate_api_key().unwrap();
1538 let keys = vec![
1539 ApiKeyEntry {
1540 name: "wrong".into(),
1541 hash: "$argon2id$v=19$m=19456,t=2,p=1$invalid$invalid".into(),
1542 role: "ops".into(),
1543 expires_at: None,
1544 },
1545 ApiKeyEntry {
1546 name: "correct".into(),
1547 hash,
1548 role: "deploy".into(),
1549 expires_at: None,
1550 },
1551 ];
1552 let id = verify_bearer_token(&token, &keys).unwrap();
1553 assert_eq!(id.name, "correct");
1554 assert_eq!(id.role, "deploy");
1555 }
1556
1557 #[test]
1558 fn rate_limiter_allows_within_quota() {
1559 let config = RateLimitConfig {
1560 max_attempts_per_minute: 5,
1561 pre_auth_max_per_minute: None,
1562 ..Default::default()
1563 };
1564 let limiter = build_rate_limiter(&config);
1565 let ip: IpAddr = "10.0.0.1".parse().unwrap();
1566
1567 for _ in 0..5 {
1569 assert!(limiter.check_key(&ip).is_ok());
1570 }
1571 assert!(limiter.check_key(&ip).is_err());
1573 }
1574
1575 #[test]
1576 fn rate_limiter_separate_ips() {
1577 let config = RateLimitConfig {
1578 max_attempts_per_minute: 2,
1579 pre_auth_max_per_minute: None,
1580 ..Default::default()
1581 };
1582 let limiter = build_rate_limiter(&config);
1583 let ip1: IpAddr = "10.0.0.1".parse().unwrap();
1584 let ip2: IpAddr = "10.0.0.2".parse().unwrap();
1585
1586 assert!(limiter.check_key(&ip1).is_ok());
1588 assert!(limiter.check_key(&ip1).is_ok());
1589 assert!(limiter.check_key(&ip1).is_err());
1590
1591 assert!(limiter.check_key(&ip2).is_ok());
1593 }
1594
1595 #[test]
1596 fn extract_mtls_identity_from_cn() {
1597 let mut params = rcgen::CertificateParams::new(vec!["test-client.local".into()]).unwrap();
1599 params.distinguished_name = rcgen::DistinguishedName::new();
1600 params
1601 .distinguished_name
1602 .push(rcgen::DnType::CommonName, "test-client");
1603 let cert = params
1604 .self_signed(&rcgen::KeyPair::generate().unwrap())
1605 .unwrap();
1606 let der = cert.der();
1607
1608 let id = extract_mtls_identity(der, "ops").unwrap();
1609 assert_eq!(id.name, "test-client");
1610 assert_eq!(id.role, "ops");
1611 assert_eq!(id.method, AuthMethod::MtlsCertificate);
1612 }
1613
1614 #[test]
1615 fn extract_mtls_identity_falls_back_to_san() {
1616 let mut params =
1618 rcgen::CertificateParams::new(vec!["san-only.example.com".into()]).unwrap();
1619 params.distinguished_name = rcgen::DistinguishedName::new();
1620 let cert = params
1622 .self_signed(&rcgen::KeyPair::generate().unwrap())
1623 .unwrap();
1624 let der = cert.der();
1625
1626 let id = extract_mtls_identity(der, "viewer").unwrap();
1627 assert_eq!(id.name, "san-only.example.com");
1628 assert_eq!(id.role, "viewer");
1629 }
1630
1631 #[test]
1632 fn extract_mtls_identity_invalid_der() {
1633 assert!(extract_mtls_identity(b"not-a-cert", "viewer").is_none());
1634 }
1635
1636 use axum::{
1639 body::Body,
1640 http::{Request, StatusCode},
1641 };
1642 use tower::ServiceExt as _;
1643
1644 fn auth_router(state: Arc<AuthState>) -> axum::Router {
1645 axum::Router::new()
1646 .route("/mcp", axum::routing::post(|| async { "ok" }))
1647 .layer(axum::middleware::from_fn(move |req, next| {
1648 let s = Arc::clone(&state);
1649 auth_middleware(s, req, next)
1650 }))
1651 }
1652
1653 fn test_auth_state(keys: Vec<ApiKeyEntry>) -> Arc<AuthState> {
1654 Arc::new(AuthState {
1655 api_keys: ArcSwap::new(Arc::new(keys)),
1656 rate_limiter: None,
1657 pre_auth_limiter: None,
1658 #[cfg(feature = "oauth")]
1659 jwks_cache: None,
1660 seen_identities: Mutex::new(HashSet::new()),
1661 counters: AuthCounters::default(),
1662 })
1663 }
1664
1665 #[tokio::test]
1666 async fn middleware_rejects_no_credentials() {
1667 let state = test_auth_state(vec![]);
1668 let app = auth_router(Arc::clone(&state));
1669 let req = Request::builder()
1670 .method(axum::http::Method::POST)
1671 .uri("/mcp")
1672 .body(Body::empty())
1673 .unwrap();
1674 let resp = app.oneshot(req).await.unwrap();
1675 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
1676 let challenge = resp
1677 .headers()
1678 .get(header::WWW_AUTHENTICATE)
1679 .unwrap()
1680 .to_str()
1681 .unwrap();
1682 assert!(challenge.contains("error=\"invalid_request\""));
1683
1684 let counters = state.counters_snapshot();
1685 assert_eq!(counters.failure_missing_credential, 1);
1686 }
1687
1688 #[tokio::test]
1689 async fn middleware_accepts_valid_bearer() {
1690 let (token, hash) = generate_api_key().unwrap();
1691 let keys = vec![ApiKeyEntry {
1692 name: "test-key".into(),
1693 hash,
1694 role: "ops".into(),
1695 expires_at: None,
1696 }];
1697 let state = test_auth_state(keys);
1698 let app = auth_router(Arc::clone(&state));
1699 let req = Request::builder()
1700 .method(axum::http::Method::POST)
1701 .uri("/mcp")
1702 .header("authorization", format!("Bearer {token}"))
1703 .body(Body::empty())
1704 .unwrap();
1705 let resp = app.oneshot(req).await.unwrap();
1706 assert_eq!(resp.status(), StatusCode::OK);
1707
1708 let counters = state.counters_snapshot();
1709 assert_eq!(counters.success_bearer, 1);
1710 }
1711
1712 #[tokio::test]
1713 async fn middleware_rejects_wrong_bearer() {
1714 let (_token, hash) = generate_api_key().unwrap();
1715 let keys = vec![ApiKeyEntry {
1716 name: "test-key".into(),
1717 hash,
1718 role: "ops".into(),
1719 expires_at: None,
1720 }];
1721 let state = test_auth_state(keys);
1722 let app = auth_router(Arc::clone(&state));
1723 let req = Request::builder()
1724 .method(axum::http::Method::POST)
1725 .uri("/mcp")
1726 .header("authorization", "Bearer wrong-token-here")
1727 .body(Body::empty())
1728 .unwrap();
1729 let resp = app.oneshot(req).await.unwrap();
1730 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
1731 let challenge = resp
1732 .headers()
1733 .get(header::WWW_AUTHENTICATE)
1734 .unwrap()
1735 .to_str()
1736 .unwrap();
1737 assert!(challenge.contains("error=\"invalid_token\""));
1738
1739 let counters = state.counters_snapshot();
1740 assert_eq!(counters.failure_invalid_credential, 1);
1741 }
1742
1743 #[tokio::test]
1744 async fn middleware_rate_limits() {
1745 let state = Arc::new(AuthState {
1746 api_keys: ArcSwap::new(Arc::new(vec![])),
1747 rate_limiter: Some(build_rate_limiter(&RateLimitConfig {
1748 max_attempts_per_minute: 1,
1749 pre_auth_max_per_minute: None,
1750 ..Default::default()
1751 })),
1752 pre_auth_limiter: None,
1753 #[cfg(feature = "oauth")]
1754 jwks_cache: None,
1755 seen_identities: Mutex::new(HashSet::new()),
1756 counters: AuthCounters::default(),
1757 });
1758 let app = auth_router(state);
1759
1760 let req = Request::builder()
1762 .method(axum::http::Method::POST)
1763 .uri("/mcp")
1764 .body(Body::empty())
1765 .unwrap();
1766 let resp = app.clone().oneshot(req).await.unwrap();
1767 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
1768
1769 }
1774
1775 #[test]
1781 fn rate_limit_semantics_failed_only() {
1782 let config = RateLimitConfig {
1783 max_attempts_per_minute: 3,
1784 pre_auth_max_per_minute: None,
1785 ..Default::default()
1786 };
1787 let limiter = build_rate_limiter(&config);
1788 let ip: IpAddr = "192.168.1.100".parse().unwrap();
1789
1790 assert!(
1792 limiter.check_key(&ip).is_ok(),
1793 "failure 1 should be allowed"
1794 );
1795 assert!(
1796 limiter.check_key(&ip).is_ok(),
1797 "failure 2 should be allowed"
1798 );
1799 assert!(
1800 limiter.check_key(&ip).is_ok(),
1801 "failure 3 should be allowed"
1802 );
1803 assert!(
1804 limiter.check_key(&ip).is_err(),
1805 "failure 4 should be blocked"
1806 );
1807
1808 }
1817
1818 #[test]
1823 fn pre_auth_default_multiplier_is_10x() {
1824 let config = RateLimitConfig {
1825 max_attempts_per_minute: 5,
1826 pre_auth_max_per_minute: None,
1827 ..Default::default()
1828 };
1829 let limiter = build_pre_auth_limiter(&config);
1830 let ip: IpAddr = "10.0.0.1".parse().unwrap();
1831
1832 for i in 0..50 {
1834 assert!(
1835 limiter.check_key(&ip).is_ok(),
1836 "pre-auth attempt {i} (of expected 50) should be allowed under default 10x multiplier"
1837 );
1838 }
1839 assert!(
1841 limiter.check_key(&ip).is_err(),
1842 "pre-auth attempt 51 should be blocked (quota is 50, not unbounded)"
1843 );
1844 }
1845
1846 #[test]
1849 fn pre_auth_explicit_override_wins() {
1850 let config = RateLimitConfig {
1851 max_attempts_per_minute: 100, pre_auth_max_per_minute: Some(2), ..Default::default()
1854 };
1855 let limiter = build_pre_auth_limiter(&config);
1856 let ip: IpAddr = "10.0.0.2".parse().unwrap();
1857
1858 assert!(limiter.check_key(&ip).is_ok(), "attempt 1 allowed");
1859 assert!(limiter.check_key(&ip).is_ok(), "attempt 2 allowed");
1860 assert!(
1861 limiter.check_key(&ip).is_err(),
1862 "attempt 3 must be blocked (explicit override of 2 wins over 10x default of 1000)"
1863 );
1864 }
1865
1866 #[tokio::test]
1872 async fn pre_auth_gate_blocks_before_argon2_verification() {
1873 let (_token, hash) = generate_api_key().unwrap();
1874 let keys = vec![ApiKeyEntry {
1875 name: "test-key".into(),
1876 hash,
1877 role: "ops".into(),
1878 expires_at: None,
1879 }];
1880 let config = RateLimitConfig {
1881 max_attempts_per_minute: 100,
1882 pre_auth_max_per_minute: Some(1),
1883 ..Default::default()
1884 };
1885 let state = Arc::new(AuthState {
1886 api_keys: ArcSwap::new(Arc::new(keys)),
1887 rate_limiter: None,
1888 pre_auth_limiter: Some(build_pre_auth_limiter(&config)),
1889 #[cfg(feature = "oauth")]
1890 jwks_cache: None,
1891 seen_identities: Mutex::new(HashSet::new()),
1892 counters: AuthCounters::default(),
1893 });
1894 let app = auth_router(Arc::clone(&state));
1895 let peer: SocketAddr = "10.0.0.10:54321".parse().unwrap();
1896
1897 let mut req1 = Request::builder()
1900 .method(axum::http::Method::POST)
1901 .uri("/mcp")
1902 .header("authorization", "Bearer obviously-not-a-real-token")
1903 .body(Body::empty())
1904 .unwrap();
1905 req1.extensions_mut().insert(ConnectInfo(peer));
1906 let resp1 = app.clone().oneshot(req1).await.unwrap();
1907 assert_eq!(
1908 resp1.status(),
1909 StatusCode::UNAUTHORIZED,
1910 "first attempt: gate has quota, falls through to bearer auth which fails with 401"
1911 );
1912
1913 let mut req2 = Request::builder()
1916 .method(axum::http::Method::POST)
1917 .uri("/mcp")
1918 .header("authorization", "Bearer also-not-a-real-token")
1919 .body(Body::empty())
1920 .unwrap();
1921 req2.extensions_mut().insert(ConnectInfo(peer));
1922 let resp2 = app.oneshot(req2).await.unwrap();
1923 assert_eq!(
1924 resp2.status(),
1925 StatusCode::TOO_MANY_REQUESTS,
1926 "second attempt from same IP: pre-auth gate must reject with 429"
1927 );
1928
1929 let counters = state.counters_snapshot();
1930 assert_eq!(
1931 counters.failure_pre_auth_gate, 1,
1932 "exactly one request must have been rejected by the pre-auth gate"
1933 );
1934 assert_eq!(
1938 counters.failure_invalid_credential, 1,
1939 "bearer verification must run exactly once (only the un-gated first request)"
1940 );
1941 }
1942
1943 #[tokio::test]
1950 async fn pre_auth_gate_does_not_throttle_mtls() {
1951 let config = RateLimitConfig {
1952 max_attempts_per_minute: 100,
1953 pre_auth_max_per_minute: Some(1), ..Default::default()
1955 };
1956 let state = Arc::new(AuthState {
1957 api_keys: ArcSwap::new(Arc::new(vec![])),
1958 rate_limiter: None,
1959 pre_auth_limiter: Some(build_pre_auth_limiter(&config)),
1960 #[cfg(feature = "oauth")]
1961 jwks_cache: None,
1962 seen_identities: Mutex::new(HashSet::new()),
1963 counters: AuthCounters::default(),
1964 });
1965 let app = auth_router(Arc::clone(&state));
1966 let peer: SocketAddr = "10.0.0.20:54321".parse().unwrap();
1967 let identity = AuthIdentity {
1968 name: "cn=test-client".into(),
1969 role: "viewer".into(),
1970 method: AuthMethod::MtlsCertificate,
1971 raw_token: None,
1972 sub: None,
1973 };
1974 let tls_info = TlsConnInfo::new(peer, Some(identity));
1975
1976 for i in 0..3 {
1977 let mut req = Request::builder()
1978 .method(axum::http::Method::POST)
1979 .uri("/mcp")
1980 .body(Body::empty())
1981 .unwrap();
1982 req.extensions_mut().insert(ConnectInfo(tls_info.clone()));
1983 let resp = app.clone().oneshot(req).await.unwrap();
1984 assert_eq!(
1985 resp.status(),
1986 StatusCode::OK,
1987 "mTLS request {i} must succeed: pre-auth gate must not apply to mTLS callers"
1988 );
1989 }
1990
1991 let counters = state.counters_snapshot();
1992 assert_eq!(
1993 counters.failure_pre_auth_gate, 0,
1994 "pre-auth gate counter must remain at zero: mTLS bypasses the gate"
1995 );
1996 assert_eq!(
1997 counters.success_mtls, 3,
1998 "all three mTLS requests must have been counted as successful"
1999 );
2000 }
2001
2002 #[test]
2007 fn extract_bearer_accepts_canonical_case() {
2008 assert_eq!(extract_bearer("Bearer abc123"), Some("abc123"));
2009 }
2010
2011 #[test]
2012 fn extract_bearer_is_case_insensitive_per_rfc7235() {
2013 for header in &[
2017 "bearer abc123",
2018 "BEARER abc123",
2019 "BeArEr abc123",
2020 "bEaReR abc123",
2021 ] {
2022 assert_eq!(
2023 extract_bearer(header),
2024 Some("abc123"),
2025 "header {header:?} must parse as a Bearer token (RFC 7235 §2.1)"
2026 );
2027 }
2028 }
2029
2030 #[test]
2031 fn extract_bearer_rejects_other_schemes() {
2032 assert_eq!(extract_bearer("Basic dXNlcjpwYXNz"), None);
2033 assert_eq!(extract_bearer("Digest username=\"x\""), None);
2034 assert_eq!(extract_bearer("Token abc123"), None);
2035 }
2036
2037 #[test]
2038 fn extract_bearer_rejects_malformed() {
2039 assert_eq!(extract_bearer(""), None);
2041 assert_eq!(extract_bearer("Bearer"), None);
2042 assert_eq!(extract_bearer("Bearer "), None);
2043 assert_eq!(extract_bearer("Bearer "), None);
2044 }
2045
2046 #[test]
2047 fn extract_bearer_tolerates_extra_separator_whitespace() {
2048 assert_eq!(extract_bearer("Bearer abc123"), Some("abc123"));
2050 assert_eq!(extract_bearer("Bearer abc123"), Some("abc123"));
2051 }
2052
2053 #[test]
2059 fn auth_identity_debug_redacts_raw_token() {
2060 let id = AuthIdentity {
2061 name: "alice".into(),
2062 role: "admin".into(),
2063 method: AuthMethod::OAuthJwt,
2064 raw_token: Some(SecretString::from("super-secret-jwt-payload-xyz")),
2065 sub: Some("keycloak-uuid-2f3c8b".into()),
2066 };
2067 let dbg = format!("{id:?}");
2068
2069 assert!(dbg.contains("alice"), "name should be visible: {dbg}");
2071 assert!(dbg.contains("admin"), "role should be visible: {dbg}");
2072 assert!(dbg.contains("OAuthJwt"), "method should be visible: {dbg}");
2073
2074 assert!(
2076 !dbg.contains("super-secret-jwt-payload-xyz"),
2077 "raw_token must be redacted in Debug output: {dbg}"
2078 );
2079 assert!(
2080 !dbg.contains("keycloak-uuid-2f3c8b"),
2081 "sub must be redacted in Debug output: {dbg}"
2082 );
2083 assert!(
2084 dbg.contains("<redacted>"),
2085 "redaction marker missing: {dbg}"
2086 );
2087 }
2088
2089 #[test]
2090 fn auth_identity_debug_marks_absent_secrets() {
2091 let id = AuthIdentity {
2094 name: "viewer-key".into(),
2095 role: "viewer".into(),
2096 method: AuthMethod::BearerToken,
2097 raw_token: None,
2098 sub: None,
2099 };
2100 let dbg = format!("{id:?}");
2101 assert!(
2102 dbg.contains("<none>"),
2103 "absent secrets should be marked: {dbg}"
2104 );
2105 assert!(
2106 !dbg.contains("<redacted>"),
2107 "no <redacted> marker when secrets are absent: {dbg}"
2108 );
2109 }
2110
2111 #[test]
2112 fn api_key_entry_debug_redacts_hash() {
2113 let entry = ApiKeyEntry {
2114 name: "viewer-key".into(),
2115 hash: "$argon2id$v=19$m=19456,t=2,p=1$c2FsdHNhbHQ$h4sh3dPa55w0rd".into(),
2117 role: "viewer".into(),
2118 expires_at: Some(RfcTimestamp::parse("2030-01-01T00:00:00Z").unwrap()),
2119 };
2120 let dbg = format!("{entry:?}");
2121
2122 assert!(dbg.contains("viewer-key"));
2124 assert!(dbg.contains("viewer"));
2125 assert!(dbg.contains("2030-01-01T00:00:00+00:00"));
2126
2127 assert!(
2129 !dbg.contains("$argon2id$"),
2130 "argon2 hash leaked into Debug output: {dbg}"
2131 );
2132 assert!(
2133 !dbg.contains("h4sh3dPa55w0rd"),
2134 "hash digest leaked into Debug output: {dbg}"
2135 );
2136 assert!(
2137 dbg.contains("<redacted>"),
2138 "redaction marker missing: {dbg}"
2139 );
2140 }
2141
2142 #[test]
2153 fn auth_failure_class_as_str_exact_strings() {
2154 assert_eq!(
2155 AuthFailureClass::MissingCredential.as_str(),
2156 "missing_credential"
2157 );
2158 assert_eq!(
2159 AuthFailureClass::InvalidCredential.as_str(),
2160 "invalid_credential"
2161 );
2162 assert_eq!(
2163 AuthFailureClass::ExpiredCredential.as_str(),
2164 "expired_credential"
2165 );
2166 assert_eq!(AuthFailureClass::RateLimited.as_str(), "rate_limited");
2167 assert_eq!(AuthFailureClass::PreAuthGate.as_str(), "pre_auth_gate");
2168 }
2169
2170 #[test]
2171 fn auth_failure_class_response_body_exact_strings() {
2172 assert_eq!(
2173 AuthFailureClass::MissingCredential.response_body(),
2174 "unauthorized: missing credential"
2175 );
2176 assert_eq!(
2177 AuthFailureClass::InvalidCredential.response_body(),
2178 "unauthorized: invalid credential"
2179 );
2180 assert_eq!(
2181 AuthFailureClass::ExpiredCredential.response_body(),
2182 "unauthorized: expired credential"
2183 );
2184 assert_eq!(
2185 AuthFailureClass::RateLimited.response_body(),
2186 "rate limited"
2187 );
2188 assert_eq!(
2189 AuthFailureClass::PreAuthGate.response_body(),
2190 "rate limited (pre-auth)"
2191 );
2192 }
2193
2194 #[test]
2195 fn auth_failure_class_bearer_error_exact_strings() {
2196 assert_eq!(
2197 AuthFailureClass::MissingCredential.bearer_error(),
2198 (
2199 "invalid_request",
2200 "missing bearer token or mTLS client certificate"
2201 )
2202 );
2203 assert_eq!(
2204 AuthFailureClass::InvalidCredential.bearer_error(),
2205 ("invalid_token", "token is invalid")
2206 );
2207 assert_eq!(
2208 AuthFailureClass::ExpiredCredential.bearer_error(),
2209 ("invalid_token", "token is expired")
2210 );
2211 assert_eq!(
2212 AuthFailureClass::RateLimited.bearer_error(),
2213 ("invalid_request", "too many failed authentication attempts")
2214 );
2215 assert_eq!(
2216 AuthFailureClass::PreAuthGate.bearer_error(),
2217 (
2218 "invalid_request",
2219 "too many unauthenticated requests from this source"
2220 )
2221 );
2222 }
2223
2224 #[test]
2233 fn auth_config_summary_bearer_true_when_keys_present() {
2234 let (_token, hash) = generate_api_key().unwrap();
2235 let cfg = AuthConfig::with_keys(vec![ApiKeyEntry::new("k", hash, "viewer")]);
2236 let s = cfg.summary();
2237 assert!(s.enabled, "summary.enabled must reflect AuthConfig.enabled");
2238 assert!(
2239 s.bearer,
2240 "summary.bearer must be true when api_keys is non-empty (kills `!` deletion at L615)"
2241 );
2242 assert!(!s.mtls, "summary.mtls must be false when mtls is None");
2243 assert!(!s.oauth, "summary.oauth must be false when oauth is None");
2244 assert_eq!(s.api_keys.len(), 1);
2245 assert_eq!(s.api_keys[0].name, "k");
2246 assert_eq!(s.api_keys[0].role, "viewer");
2247 }
2248
2249 #[test]
2250 fn auth_config_summary_bearer_false_when_no_keys() {
2251 let cfg = AuthConfig::with_keys(vec![]);
2252 let s = cfg.summary();
2253 assert!(
2254 !s.bearer,
2255 "summary.bearer must be false when api_keys is empty (kills `!` deletion at L615)"
2256 );
2257 assert!(s.api_keys.is_empty());
2258 }
2259}