1use std::{
10 collections::HashSet,
11 net::{IpAddr, SocketAddr},
12 num::NonZeroU32,
13 path::PathBuf,
14 sync::{
15 Arc, Mutex,
16 atomic::{AtomicU64, Ordering},
17 },
18 time::Duration,
19};
20
21use arc_swap::ArcSwap;
22use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier, password_hash::SaltString};
23use axum::{
24 body::Body,
25 extract::ConnectInfo,
26 http::{Request, header},
27 middleware::Next,
28 response::{IntoResponse, Response},
29};
30use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
31use secrecy::SecretString;
32use serde::Deserialize;
33use x509_parser::prelude::*;
34
35use crate::{bounded_limiter::BoundedKeyedLimiter, error::McpxError};
36
/// Identity of a caller that passed authentication.
///
/// Values are built by [`extract_mtls_identity`] and [`verify_bearer_token`]
/// in this file, and the auth middleware inserts the identity into the
/// request extensions before running the next handler.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct AuthIdentity {
    /// Caller name: the API key's `name`, or the certificate CN / DNS SAN for mTLS.
    pub name: String,
    /// Role attached to the caller: the API key's `role`, or the mTLS default role.
    pub role: String,
    /// Mechanism that authenticated the caller.
    pub method: AuthMethod,
    /// Bearer token as presented by the client. In this file it is populated
    /// only after a successful JWT validation; API-key and mTLS identities
    /// carry `None`.
    pub raw_token: Option<SecretString>,
    /// `None` for every identity constructed in this file.
    /// NOTE(review): presumably the OAuth `sub` claim, set by the OAuth module — confirm.
    pub sub: Option<String>,
}
57
/// Mechanism by which a caller was authenticated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum AuthMethod {
    /// Static API key presented as a bearer token and checked against a stored hash.
    BearerToken,
    /// Identity taken from the client certificate of an mTLS connection.
    MtlsCertificate,
    /// OAuth JWT presented as a bearer token.
    OAuthJwt,
}
69
/// Reason an authentication attempt was rejected.
///
/// Each class maps to a metrics label, an OAuth-style bearer error pair for
/// the `WWW-Authenticate` challenge, and a plain-text response body.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AuthFailureClass {
    /// The request carried no credential at all.
    MissingCredential,
    /// A credential was presented but did not verify.
    InvalidCredential,
    /// A credential was presented but is past its validity period.
    #[cfg_attr(not(feature = "oauth"), allow(dead_code))]
    ExpiredCredential,
    /// The caller exceeded the failed-attempt rate limit.
    RateLimited,
    /// The caller was rejected before any credential verification took place.
    PreAuthGate,
}

impl AuthFailureClass {
    /// Stable snake_case label used in structured logs.
    fn as_str(self) -> &'static str {
        match self {
            Self::PreAuthGate => "pre_auth_gate",
            Self::RateLimited => "rate_limited",
            Self::ExpiredCredential => "expired_credential",
            Self::InvalidCredential => "invalid_credential",
            Self::MissingCredential => "missing_credential",
        }
    }

    /// Returns the `(error, error_description)` pair placed in the
    /// `WWW-Authenticate: Bearer` challenge.
    fn bearer_error(self) -> (&'static str, &'static str) {
        // Only credential problems are token errors; everything else is
        // reported as a malformed/unacceptable request.
        let code = match self {
            Self::InvalidCredential | Self::ExpiredCredential => "invalid_token",
            Self::MissingCredential | Self::RateLimited | Self::PreAuthGate => "invalid_request",
        };
        let description = match self {
            Self::MissingCredential => "missing bearer token or mTLS client certificate",
            Self::InvalidCredential => "token is invalid",
            Self::ExpiredCredential => "token is expired",
            Self::RateLimited => "too many failed authentication attempts",
            Self::PreAuthGate => "too many unauthenticated requests from this source",
        };
        (code, description)
    }

    /// Plain-text body sent with the rejection response.
    fn response_body(self) -> &'static str {
        match self {
            Self::PreAuthGate => "rate limited (pre-auth)",
            Self::RateLimited => "rate limited",
            Self::ExpiredCredential => "unauthorized: expired credential",
            Self::InvalidCredential => "unauthorized: invalid credential",
            Self::MissingCredential => "unauthorized: missing credential",
        }
    }
}
120
/// Point-in-time copy of the authentication counters, suitable for serialization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
#[non_exhaustive]
pub struct AuthCountersSnapshot {
    /// Successful mTLS authentications.
    pub success_mtls: u64,
    /// Successful API-key bearer authentications.
    pub success_bearer: u64,
    /// Successful OAuth JWT authentications.
    pub success_oauth_jwt: u64,
    /// Rejections because no credential was presented.
    pub failure_missing_credential: u64,
    /// Rejections because the presented credential did not verify.
    pub failure_invalid_credential: u64,
    /// Rejections because the presented credential had expired.
    pub failure_expired_credential: u64,
    /// Rejections by the failed-attempt rate limiter.
    pub failure_rate_limited: u64,
    /// Rejections by the pre-auth gate, before any credential verification.
    pub failure_pre_auth_gate: u64,
}
143
/// Live authentication counters, one atomic per success method and failure class.
///
/// All counters start at zero (via `Default`) and are read out together by
/// `snapshot`.
#[derive(Debug, Default)]
pub(crate) struct AuthCounters {
    // Successes, by authentication method.
    success_mtls: AtomicU64,
    success_bearer: AtomicU64,
    success_oauth_jwt: AtomicU64,
    // Failures, by `AuthFailureClass`.
    failure_missing_credential: AtomicU64,
    failure_invalid_credential: AtomicU64,
    failure_expired_credential: AtomicU64,
    failure_rate_limited: AtomicU64,
    failure_pre_auth_gate: AtomicU64,
}
156
157impl AuthCounters {
158 fn record_success(&self, method: AuthMethod) {
159 match method {
160 AuthMethod::MtlsCertificate => {
161 self.success_mtls.fetch_add(1, Ordering::Relaxed);
162 }
163 AuthMethod::BearerToken => {
164 self.success_bearer.fetch_add(1, Ordering::Relaxed);
165 }
166 AuthMethod::OAuthJwt => {
167 self.success_oauth_jwt.fetch_add(1, Ordering::Relaxed);
168 }
169 }
170 }
171
172 fn record_failure(&self, class: AuthFailureClass) {
173 match class {
174 AuthFailureClass::MissingCredential => {
175 self.failure_missing_credential
176 .fetch_add(1, Ordering::Relaxed);
177 }
178 AuthFailureClass::InvalidCredential => {
179 self.failure_invalid_credential
180 .fetch_add(1, Ordering::Relaxed);
181 }
182 AuthFailureClass::ExpiredCredential => {
183 self.failure_expired_credential
184 .fetch_add(1, Ordering::Relaxed);
185 }
186 AuthFailureClass::RateLimited => {
187 self.failure_rate_limited.fetch_add(1, Ordering::Relaxed);
188 }
189 AuthFailureClass::PreAuthGate => {
190 self.failure_pre_auth_gate.fetch_add(1, Ordering::Relaxed);
191 }
192 }
193 }
194
195 fn snapshot(&self) -> AuthCountersSnapshot {
196 AuthCountersSnapshot {
197 success_mtls: self.success_mtls.load(Ordering::Relaxed),
198 success_bearer: self.success_bearer.load(Ordering::Relaxed),
199 success_oauth_jwt: self.success_oauth_jwt.load(Ordering::Relaxed),
200 failure_missing_credential: self.failure_missing_credential.load(Ordering::Relaxed),
201 failure_invalid_credential: self.failure_invalid_credential.load(Ordering::Relaxed),
202 failure_expired_credential: self.failure_expired_credential.load(Ordering::Relaxed),
203 failure_rate_limited: self.failure_rate_limited.load(Ordering::Relaxed),
204 failure_pre_auth_gate: self.failure_pre_auth_gate.load(Ordering::Relaxed),
205 }
206 }
207}
208
/// One configured API key, as deserialized from configuration.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ApiKeyEntry {
    /// Name reported as the identity of callers using this key.
    pub name: String,
    /// Password-hash string of the token (parsed with `PasswordHash::new` and
    /// verified with Argon2); the plaintext token is never stored here.
    pub hash: String,
    /// Role granted to callers using this key.
    pub role: String,
    /// Optional expiry as an RFC 3339 timestamp; a key whose expiry lies in
    /// the past is skipped during verification.
    pub expires_at: Option<String>,
}
222
223impl ApiKeyEntry {
224 #[must_use]
226 pub fn new(name: impl Into<String>, hash: impl Into<String>, role: impl Into<String>) -> Self {
227 Self {
228 name: name.into(),
229 hash: hash.into(),
230 role: role.into(),
231 expires_at: None,
232 }
233 }
234
235 #[must_use]
237 pub fn with_expiry(mut self, expires_at: impl Into<String>) -> Self {
238 self.expires_at = Some(expires_at.into());
239 self
240 }
241}
242
/// mTLS configuration, including certificate-revocation-list (CRL) settings.
///
/// Only the field types and serde defaults are established by this file; the
/// behavior each option controls is implemented elsewhere, so the per-field
/// descriptions below are marked where they rest on the field name alone.
#[derive(Debug, Clone, Deserialize)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "mTLS CRL behavior is intentionally configured as independent booleans"
)]
#[non_exhaustive]
pub struct MtlsConfig {
    /// Path to the CA certificate file. Required (no serde default).
    /// NOTE(review): presumably the CA used to verify client certificates — confirm.
    pub ca_cert_path: PathBuf,
    /// Defaults to `false`.
    /// NOTE(review): presumably whether a client certificate is mandatory — confirm.
    #[serde(default)]
    pub required: bool,
    /// Role given to mTLS-authenticated callers. Defaults to `"viewer"`.
    #[serde(default = "default_mtls_role")]
    pub default_role: String,
    /// Defaults to `true`.
    /// NOTE(review): presumably turns CRL checking on or off — confirm.
    #[serde(default = "default_true")]
    pub crl_enabled: bool,
    /// Optional humantime duration; defaults to `None`.
    /// NOTE(review): presumably the CRL refresh period — confirm what `None` means.
    #[serde(default, with = "humantime_serde::option")]
    pub crl_refresh_interval: Option<Duration>,
    /// Humantime duration; defaults to 30 seconds.
    /// NOTE(review): presumably the timeout for a single CRL fetch — confirm.
    #[serde(default = "default_crl_fetch_timeout", with = "humantime_serde")]
    pub crl_fetch_timeout: Duration,
    /// Humantime duration; defaults to 24 hours.
    /// NOTE(review): presumably how long a stale CRL is still accepted — confirm.
    #[serde(default = "default_crl_stale_grace", with = "humantime_serde")]
    pub crl_stale_grace: Duration,
    /// Defaults to `false`.
    /// NOTE(review): presumably rejects certificates when no CRL is available — confirm.
    #[serde(default)]
    pub crl_deny_on_unavailable: bool,
    /// Defaults to `false`.
    /// NOTE(review): presumably limits revocation checks to the leaf certificate — confirm.
    #[serde(default)]
    pub crl_end_entity_only: bool,
    /// Defaults to `true`.
    /// NOTE(review): presumably permits plain-HTTP CRL distribution points — confirm.
    #[serde(default = "default_true")]
    pub crl_allow_http: bool,
    /// Defaults to `true`.
    /// NOTE(review): presumably rejects CRLs past their validity period — confirm.
    #[serde(default = "default_true")]
    pub crl_enforce_expiration: bool,
    /// Defaults to `4`.
    /// NOTE(review): presumably caps simultaneous CRL downloads — confirm.
    #[serde(default = "default_crl_max_concurrent_fetches")]
    pub crl_max_concurrent_fetches: usize,
    /// Defaults to 5 MiB (`5 * 1024 * 1024`).
    /// NOTE(review): presumably caps the size of a fetched CRL body — confirm.
    #[serde(default = "default_crl_max_response_bytes")]
    pub crl_max_response_bytes: u64,
    /// Defaults to `60`.
    /// NOTE(review): presumably limits newly discovered CRL URLs per minute — confirm.
    #[serde(default = "default_crl_discovery_rate_per_min")]
    pub crl_discovery_rate_per_min: u32,
    /// Defaults to `1024`.
    /// NOTE(review): presumably bounds the number of per-host semaphores kept — confirm.
    #[serde(default = "default_crl_max_host_semaphores")]
    pub crl_max_host_semaphores: usize,
    /// Defaults to `4096`.
    /// NOTE(review): presumably bounds the set of remembered CRL URLs — confirm.
    #[serde(default = "default_crl_max_seen_urls")]
    pub crl_max_seen_urls: usize,
    /// Defaults to `1024`.
    /// NOTE(review): presumably bounds the number of cached CRLs — confirm.
    #[serde(default = "default_crl_max_cache_entries")]
    pub crl_max_cache_entries: usize,
}
340
/// Serde default for `MtlsConfig::default_role`.
fn default_mtls_role() -> String {
    String::from("viewer")
}
344
/// Serde default for boolean options that are enabled unless explicitly turned off.
const fn default_true() -> bool {
    true
}
348
/// Serde default for `MtlsConfig::crl_fetch_timeout`: 30 seconds.
const fn default_crl_fetch_timeout() -> Duration {
    const FETCH_TIMEOUT_SECS: u64 = 30;
    Duration::from_secs(FETCH_TIMEOUT_SECS)
}
352
353const fn default_crl_stale_grace() -> Duration {
354 Duration::from_hours(24)
355}
356
/// Serde default for `MtlsConfig::crl_max_concurrent_fetches`.
const fn default_crl_max_concurrent_fetches() -> usize {
    const MAX_CONCURRENT_FETCHES: usize = 4;
    MAX_CONCURRENT_FETCHES
}
360
/// Serde default for `MtlsConfig::crl_max_response_bytes`: 5 MiB.
const fn default_crl_max_response_bytes() -> u64 {
    const MIB: u64 = 1024 * 1024;
    5 * MIB
}
364
/// Serde default for `MtlsConfig::crl_discovery_rate_per_min`.
const fn default_crl_discovery_rate_per_min() -> u32 {
    const DISCOVERIES_PER_MINUTE: u32 = 60;
    DISCOVERIES_PER_MINUTE
}
368
/// Serde default for `MtlsConfig::crl_max_host_semaphores`.
const fn default_crl_max_host_semaphores() -> usize {
    const MAX_HOST_SEMAPHORES: usize = 1024;
    MAX_HOST_SEMAPHORES
}
372
/// Serde default for `MtlsConfig::crl_max_seen_urls`.
const fn default_crl_max_seen_urls() -> usize {
    const MAX_SEEN_URLS: usize = 4096;
    MAX_SEEN_URLS
}
376
/// Serde default for `MtlsConfig::crl_max_cache_entries`.
const fn default_crl_max_cache_entries() -> usize {
    const MAX_CACHE_ENTRIES: usize = 1024;
    MAX_CACHE_ENTRIES
}
380
/// Per-IP rate-limit settings for authentication.
///
/// The same settings feed two limiters: the failed-attempt limiter
/// ([`build_rate_limiter`]) and the pre-auth gate ([`build_pre_auth_limiter`]).
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct RateLimitConfig {
    /// Per-minute quota of the failed-attempt limiter. Defaults to `30`;
    /// a value of `0` is replaced by `30` when the limiter is built.
    #[serde(default = "default_max_attempts")]
    pub max_attempts_per_minute: u32,
    /// Per-minute quota of the pre-auth gate. When `None`, the gate uses
    /// `max_attempts_per_minute` multiplied by 10 (saturating); a resolved
    /// value of `0` is replaced by `300`.
    #[serde(default)]
    pub pre_auth_max_per_minute: Option<u32>,
    /// Passed to each limiter as its key capacity. Defaults to `10_000`.
    /// NOTE(review): presumably the maximum number of distinct IPs tracked — confirm in `bounded_limiter`.
    #[serde(default = "default_max_tracked_keys")]
    pub max_tracked_keys: usize,
    /// Passed to each limiter as its idle period (humantime). Defaults to 15 minutes.
    /// NOTE(review): presumably the idle time after which an IP's entry is evicted — confirm in `bounded_limiter`.
    #[serde(default = "default_idle_eviction", with = "humantime_serde")]
    pub idle_eviction: Duration,
}
422
423impl Default for RateLimitConfig {
424 fn default() -> Self {
425 Self {
426 max_attempts_per_minute: default_max_attempts(),
427 pre_auth_max_per_minute: None,
428 max_tracked_keys: default_max_tracked_keys(),
429 idle_eviction: default_idle_eviction(),
430 }
431 }
432}
433
434impl RateLimitConfig {
435 #[must_use]
439 pub fn new(max_attempts_per_minute: u32) -> Self {
440 Self {
441 max_attempts_per_minute,
442 ..Self::default()
443 }
444 }
445
446 #[must_use]
449 pub fn with_pre_auth_max_per_minute(mut self, quota: u32) -> Self {
450 self.pre_auth_max_per_minute = Some(quota);
451 self
452 }
453
454 #[must_use]
456 pub fn with_max_tracked_keys(mut self, max: usize) -> Self {
457 self.max_tracked_keys = max;
458 self
459 }
460
461 #[must_use]
463 pub fn with_idle_eviction(mut self, idle: Duration) -> Self {
464 self.idle_eviction = idle;
465 self
466 }
467}
468
/// Serde default for `RateLimitConfig::max_attempts_per_minute`.
fn default_max_attempts() -> u32 {
    const MAX_ATTEMPTS_PER_MINUTE: u32 = 30;
    MAX_ATTEMPTS_PER_MINUTE
}
472
/// Serde default for `RateLimitConfig::max_tracked_keys`.
fn default_max_tracked_keys() -> usize {
    const MAX_TRACKED_KEYS: usize = 10_000;
    MAX_TRACKED_KEYS
}
476
477fn default_idle_eviction() -> Duration {
478 Duration::from_mins(15)
479}
480
/// Top-level authentication configuration.
#[derive(Debug, Clone, Default, Deserialize)]
#[non_exhaustive]
pub struct AuthConfig {
    /// Master switch; defaults to `false`.
    /// NOTE(review): how a disabled config is treated is decided outside this file — confirm.
    #[serde(default)]
    pub enabled: bool,
    /// Configured API keys; defaults to an empty list.
    #[serde(default)]
    pub api_keys: Vec<ApiKeyEntry>,
    /// mTLS settings, if mTLS is configured.
    pub mtls: Option<MtlsConfig>,
    /// Rate-limit settings, if rate limiting is configured.
    pub rate_limit: Option<RateLimitConfig>,
    /// OAuth settings, if OAuth is configured (only with the `oauth` feature).
    #[cfg(feature = "oauth")]
    pub oauth: Option<crate::oauth::OAuthConfig>,
}
499
500impl AuthConfig {
501 #[must_use]
503 pub fn with_keys(keys: Vec<ApiKeyEntry>) -> Self {
504 Self {
505 enabled: true,
506 api_keys: keys,
507 mtls: None,
508 rate_limit: None,
509 #[cfg(feature = "oauth")]
510 oauth: None,
511 }
512 }
513
514 #[must_use]
516 pub fn with_rate_limit(mut self, rate_limit: RateLimitConfig) -> Self {
517 self.rate_limit = Some(rate_limit);
518 self
519 }
520}
521
/// Serializable view of an API key that omits its hash.
#[derive(Debug, Clone, serde::Serialize)]
#[non_exhaustive]
pub struct ApiKeySummary {
    /// Key name.
    pub name: String,
    /// Role granted by the key.
    pub role: String,
    /// Expiry timestamp string, copied verbatim from the key entry.
    pub expires_at: Option<String>,
}
535
/// Serializable overview of which authentication methods are configured.
#[derive(Debug, Clone, serde::Serialize)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "this is a flat summary of independent auth-method booleans"
)]
#[non_exhaustive]
pub struct AuthConfigSummary {
    /// Value of the configuration's `enabled` flag.
    pub enabled: bool,
    /// `true` when at least one API key is configured.
    pub bearer: bool,
    /// `true` when an mTLS section is present.
    pub mtls: bool,
    /// `true` when an OAuth section is present; always `false` without the `oauth` feature.
    pub oauth: bool,
    /// Hash-free summaries of the configured API keys.
    pub api_keys: Vec<ApiKeySummary>,
}
555
556impl AuthConfig {
557 #[must_use]
559 pub fn summary(&self) -> AuthConfigSummary {
560 AuthConfigSummary {
561 enabled: self.enabled,
562 bearer: !self.api_keys.is_empty(),
563 mtls: self.mtls.is_some(),
564 #[cfg(feature = "oauth")]
565 oauth: self.oauth.is_some(),
566 #[cfg(not(feature = "oauth"))]
567 oauth: false,
568 api_keys: self
569 .api_keys
570 .iter()
571 .map(|k| ApiKeySummary {
572 name: k.name.clone(),
573 role: k.role.clone(),
574 expires_at: k.expires_at.clone(),
575 })
576 .collect(),
577 }
578 }
579}
580
/// Rate limiter keyed by client IP address; used for both the failed-attempt
/// limiter and the pre-auth gate.
pub(crate) type KeyedLimiter = BoundedKeyedLimiter<IpAddr>;
584
/// Per-connection information for TLS connections.
///
/// The auth middleware reads it from the request extensions as
/// `ConnectInfo<TlsConnInfo>`.
/// NOTE(review): presumably populated by the TLS acceptor after the handshake — confirm.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub(crate) struct TlsConnInfo {
    /// Remote peer address of the connection.
    pub addr: SocketAddr,
    /// Identity taken from the client certificate, when one is available.
    pub identity: Option<AuthIdentity>,
}
603
604impl TlsConnInfo {
605 #[must_use]
607 pub(crate) const fn new(addr: SocketAddr, identity: Option<AuthIdentity>) -> Self {
608 Self { addr, identity }
609 }
610}
611
/// Shared state consulted by the auth middleware on every request.
#[allow(
    missing_debug_implementations,
    reason = "contains governor RateLimiter and JwksCache without Debug impls"
)]
#[non_exhaustive]
pub(crate) struct AuthState {
    /// Current API keys; replaced atomically by `reload_keys`.
    pub api_keys: ArcSwap<Vec<ApiKeyEntry>>,
    /// Per-IP limiter charged after an authentication attempt fails; `None` disables it.
    pub rate_limiter: Option<Arc<KeyedLimiter>>,
    /// Per-IP limiter checked before any bearer credential is verified; `None` disables it.
    pub pre_auth_limiter: Option<Arc<KeyedLimiter>>,
    /// JWT validator; when set, JWT-shaped bearer tokens are validated with it
    /// before API keys are tried, and 401 challenges advertise resource metadata.
    #[cfg(feature = "oauth")]
    pub jwks_cache: Option<Arc<crate::oauth::JwksCache>>,
    /// Identity names seen so far, so the first authentication of a name is
    /// logged at `info` and later ones at `debug`.
    pub seen_identities: Mutex<HashSet<String>>,
    /// Success and failure counters.
    pub counters: AuthCounters,
}
638
639impl AuthState {
640 pub(crate) fn reload_keys(&self, keys: Vec<ApiKeyEntry>) {
646 let count = keys.len();
647 self.api_keys.store(Arc::new(keys));
648 tracing::info!(keys = count, "API keys reloaded");
649 }
650
651 #[must_use]
653 pub(crate) fn counters_snapshot(&self) -> AuthCountersSnapshot {
654 self.counters.snapshot()
655 }
656
657 #[must_use]
659 pub(crate) fn api_key_summaries(&self) -> Vec<ApiKeySummary> {
660 self.api_keys
661 .load()
662 .iter()
663 .map(|k| ApiKeySummary {
664 name: k.name.clone(),
665 role: k.role.clone(),
666 expires_at: k.expires_at.clone(),
667 })
668 .collect()
669 }
670
671 fn log_auth(&self, id: &AuthIdentity, method: &str) {
673 self.counters.record_success(id.method);
674 let first = self
675 .seen_identities
676 .lock()
677 .unwrap_or_else(std::sync::PoisonError::into_inner)
678 .insert(id.name.clone());
679 if first {
680 tracing::info!(name = %id.name, role = %id.role, "{method} authenticated");
681 } else {
682 tracing::debug!(name = %id.name, role = %id.role, "{method} authenticated");
683 }
684 }
685}
686
687const DEFAULT_AUTH_RATE: NonZeroU32 = NonZeroU32::new(30).unwrap();
690
691#[must_use]
693pub(crate) fn build_rate_limiter(config: &RateLimitConfig) -> Arc<KeyedLimiter> {
694 let quota = governor::Quota::per_minute(
695 NonZeroU32::new(config.max_attempts_per_minute).unwrap_or(DEFAULT_AUTH_RATE),
696 );
697 Arc::new(BoundedKeyedLimiter::new(
698 quota,
699 config.max_tracked_keys,
700 config.idle_eviction,
701 ))
702}
703
704#[must_use]
711pub(crate) fn build_pre_auth_limiter(config: &RateLimitConfig) -> Arc<KeyedLimiter> {
712 let resolved = config.pre_auth_max_per_minute.unwrap_or_else(|| {
713 config
714 .max_attempts_per_minute
715 .saturating_mul(PRE_AUTH_DEFAULT_MULTIPLIER)
716 });
717 let quota =
718 governor::Quota::per_minute(NonZeroU32::new(resolved).unwrap_or(DEFAULT_PRE_AUTH_RATE));
719 Arc::new(BoundedKeyedLimiter::new(
720 quota,
721 config.max_tracked_keys,
722 config.idle_eviction,
723 ))
724}
725
726const PRE_AUTH_DEFAULT_MULTIPLIER: u32 = 10;
729
730const DEFAULT_PRE_AUTH_RATE: NonZeroU32 = NonZeroU32::new(300).unwrap();
734
735#[must_use]
740pub fn extract_mtls_identity(cert_der: &[u8], default_role: &str) -> Option<AuthIdentity> {
741 let (_, cert) = X509Certificate::from_der(cert_der).ok()?;
742
743 let cn = cert
745 .subject()
746 .iter_common_name()
747 .next()
748 .and_then(|attr| attr.as_str().ok())
749 .map(String::from);
750
751 let name = cn.or_else(|| {
753 cert.subject_alternative_name()
754 .ok()
755 .flatten()
756 .and_then(|san| {
757 #[allow(clippy::wildcard_enum_match_arm)]
758 san.value.general_names.iter().find_map(|gn| match gn {
759 GeneralName::DNSName(dns) => Some((*dns).to_owned()),
760 _ => None,
761 })
762 })
763 })?;
764
765 if !name
767 .chars()
768 .all(|c| c.is_alphanumeric() || matches!(c, '-' | '.' | '_' | '@'))
769 {
770 tracing::warn!(cn = %name, "mTLS identity rejected: invalid characters in CN/SAN");
771 return None;
772 }
773
774 Some(AuthIdentity {
775 name,
776 role: default_role.to_owned(),
777 method: AuthMethod::MtlsCertificate,
778 raw_token: None,
779 sub: None,
780 })
781}
782
783#[must_use]
791pub fn verify_bearer_token(token: &str, keys: &[ApiKeyEntry]) -> Option<AuthIdentity> {
792 let now = chrono::Utc::now();
793
794 let mut result: Option<AuthIdentity> = None;
797
798 for key in keys {
799 if let Some(ref expires) = key.expires_at
801 && let Ok(exp) = chrono::DateTime::parse_from_rfc3339(expires)
802 && exp < now
803 {
804 continue;
805 }
806
807 if result.is_none()
810 && let Ok(parsed_hash) = PasswordHash::new(&key.hash)
811 && Argon2::default()
812 .verify_password(token.as_bytes(), &parsed_hash)
813 .is_ok()
814 {
815 result = Some(AuthIdentity {
816 name: key.name.clone(),
817 role: key.role.clone(),
818 method: AuthMethod::BearerToken,
819 raw_token: None,
820 sub: None,
821 });
822 }
823 }
824 result
825}
826
827pub fn generate_api_key() -> Result<(String, String), McpxError> {
837 let mut token_bytes = [0u8; 32];
838 rand::fill(&mut token_bytes);
839 let token = URL_SAFE_NO_PAD.encode(token_bytes);
840
841 let mut salt_bytes = [0u8; 16];
843 rand::fill(&mut salt_bytes);
844 let salt = SaltString::encode_b64(&salt_bytes)
845 .map_err(|e| McpxError::Auth(format!("salt encoding failed: {e}")))?;
846 let hash = Argon2::default()
847 .hash_password(token.as_bytes(), &salt)
848 .map_err(|e| McpxError::Auth(format!("argon2id hashing failed: {e}")))?
849 .to_string();
850
851 Ok((token, hash))
852}
853
854fn build_www_authenticate_value(
855 advertise_resource_metadata: bool,
856 failure: AuthFailureClass,
857) -> String {
858 let (error, error_description) = failure.bearer_error();
859 if advertise_resource_metadata {
860 return format!(
861 "Bearer resource_metadata=\"/.well-known/oauth-protected-resource\", error=\"{error}\", error_description=\"{error_description}\""
862 );
863 }
864 format!("Bearer error=\"{error}\", error_description=\"{error_description}\"")
865}
866
867fn auth_method_label(method: AuthMethod) -> &'static str {
868 match method {
869 AuthMethod::MtlsCertificate => "mTLS",
870 AuthMethod::BearerToken => "bearer token",
871 AuthMethod::OAuthJwt => "OAuth JWT",
872 }
873}
874
875#[cfg_attr(not(feature = "oauth"), allow(unused_variables))]
876fn unauthorized_response(state: &AuthState, failure_class: AuthFailureClass) -> Response {
877 #[cfg(feature = "oauth")]
878 let advertise_resource_metadata = state.jwks_cache.is_some();
879 #[cfg(not(feature = "oauth"))]
880 let advertise_resource_metadata = false;
881
882 let challenge = build_www_authenticate_value(advertise_resource_metadata, failure_class);
883 (
884 axum::http::StatusCode::UNAUTHORIZED,
885 [(header::WWW_AUTHENTICATE, challenge)],
886 failure_class.response_body(),
887 )
888 .into_response()
889}
890
891async fn authenticate_bearer_identity(
892 state: &AuthState,
893 token: &str,
894) -> Result<AuthIdentity, AuthFailureClass> {
895 let mut failure_class = AuthFailureClass::MissingCredential;
896
897 #[cfg(feature = "oauth")]
898 if let Some(ref cache) = state.jwks_cache
899 && crate::oauth::looks_like_jwt(token)
900 {
901 match cache.validate_token_with_reason(token).await {
902 Ok(mut id) => {
903 id.raw_token = Some(SecretString::from(token.to_owned()));
904 return Ok(id);
905 }
906 Err(crate::oauth::JwtValidationFailure::Expired) => {
907 failure_class = AuthFailureClass::ExpiredCredential;
908 }
909 Err(crate::oauth::JwtValidationFailure::Invalid) => {
910 failure_class = AuthFailureClass::InvalidCredential;
911 }
912 }
913 }
914
915 let token = token.to_owned();
916 let keys = state.api_keys.load_full(); let identity = tokio::task::spawn_blocking(move || verify_bearer_token(&token, &keys))
920 .await
921 .ok()
922 .flatten();
923
924 if let Some(id) = identity {
925 return Ok(id);
926 }
927
928 if failure_class == AuthFailureClass::MissingCredential {
929 failure_class = AuthFailureClass::InvalidCredential;
930 }
931
932 Err(failure_class)
933}
934
935fn pre_auth_gate(state: &AuthState, peer_addr: Option<SocketAddr>) -> Option<Response> {
946 let limiter = state.pre_auth_limiter.as_ref()?;
947 let addr = peer_addr?;
948 if limiter.check_key(&addr.ip()).is_ok() {
949 return None;
950 }
951 state.counters.record_failure(AuthFailureClass::PreAuthGate);
952 tracing::warn!(
953 ip = %addr.ip(),
954 "auth rate limited by pre-auth gate (request rejected before credential verification)"
955 );
956 Some(
957 McpxError::RateLimited("too many unauthenticated requests from this source".into())
958 .into_response(),
959 )
960}
961
962pub(crate) async fn auth_middleware(
971 state: Arc<AuthState>,
972 req: Request<Body>,
973 next: Next,
974) -> Response {
975 let tls_info = req.extensions().get::<ConnectInfo<TlsConnInfo>>().cloned();
980 let peer_addr = req
981 .extensions()
982 .get::<ConnectInfo<SocketAddr>>()
983 .map(|ci| ci.0)
984 .or_else(|| tls_info.as_ref().map(|ci| ci.0.addr));
985
986 if let Some(id) = tls_info.and_then(|ci| ci.0.identity) {
993 state.log_auth(&id, "mTLS");
994 let mut req = req;
995 req.extensions_mut().insert(id);
996 return next.run(req).await;
997 }
998
999 if let Some(blocked) = pre_auth_gate(&state, peer_addr) {
1003 return blocked;
1004 }
1005
1006 let failure_class = if let Some(value) = req.headers().get(header::AUTHORIZATION) {
1007 match value.to_str().ok().and_then(|v| v.strip_prefix("Bearer ")) {
1008 Some(token) => match authenticate_bearer_identity(&state, token).await {
1009 Ok(id) => {
1010 state.log_auth(&id, auth_method_label(id.method));
1011 let mut req = req;
1012 req.extensions_mut().insert(id);
1013 return next.run(req).await;
1014 }
1015 Err(class) => class,
1016 },
1017 None => AuthFailureClass::InvalidCredential,
1018 }
1019 } else {
1020 AuthFailureClass::MissingCredential
1021 };
1022
1023 tracing::warn!(failure_class = %failure_class.as_str(), "auth failed");
1024
1025 if let (Some(limiter), Some(addr)) = (&state.rate_limiter, peer_addr)
1028 && limiter.check_key(&addr.ip()).is_err()
1029 {
1030 state.counters.record_failure(AuthFailureClass::RateLimited);
1031 tracing::warn!(ip = %addr.ip(), "auth rate limited after repeated failures");
1032 return McpxError::RateLimited("too many failed authentication attempts".into())
1033 .into_response();
1034 }
1035
1036 state.counters.record_failure(failure_class);
1037 unauthorized_response(&state, failure_class)
1038}
1039
#[cfg(test)]
mod tests {
    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use tower::ServiceExt as _;

    use super::*;

    // ---- shared fixtures -------------------------------------------------

    /// Builds an API key entry from its parts.
    fn key_entry(name: &str, hash: String, role: &str, expires_at: Option<&str>) -> ApiKeyEntry {
        ApiKeyEntry {
            name: name.into(),
            hash,
            role: role.into(),
            expires_at: expires_at.map(Into::into),
        }
    }

    /// Rate-limit config with the given quotas and defaults for the rest.
    fn limits(max_attempts_per_minute: u32, pre_auth_max_per_minute: Option<u32>) -> RateLimitConfig {
        RateLimitConfig {
            max_attempts_per_minute,
            pre_auth_max_per_minute,
            ..Default::default()
        }
    }

    /// Auth state with explicit limiters and fresh counters.
    fn auth_state_with(
        keys: Vec<ApiKeyEntry>,
        rate_limiter: Option<Arc<KeyedLimiter>>,
        pre_auth_limiter: Option<Arc<KeyedLimiter>>,
    ) -> Arc<AuthState> {
        Arc::new(AuthState {
            api_keys: ArcSwap::new(Arc::new(keys)),
            rate_limiter,
            pre_auth_limiter,
            #[cfg(feature = "oauth")]
            jwks_cache: None,
            seen_identities: Mutex::new(HashSet::new()),
            counters: AuthCounters::default(),
        })
    }

    /// Auth state with the given keys and no limiters.
    fn test_auth_state(keys: Vec<ApiKeyEntry>) -> Arc<AuthState> {
        auth_state_with(keys, None, None)
    }

    /// Router with a single `POST /mcp` route behind the auth middleware.
    fn auth_router(state: Arc<AuthState>) -> axum::Router {
        let guard = axum::middleware::from_fn(move |req, next| {
            let shared = Arc::clone(&state);
            auth_middleware(shared, req, next)
        });
        axum::Router::new()
            .route("/mcp", axum::routing::post(|| async { "ok" }))
            .layer(guard)
    }

    /// `POST /mcp` with an empty body and an optional `authorization` header.
    fn post_mcp(authorization: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder()
            .method(axum::http::Method::POST)
            .uri("/mcp");
        if let Some(value) = authorization {
            builder = builder.header("authorization", value);
        }
        builder.body(Body::empty()).unwrap()
    }

    /// The `WWW-Authenticate` header of a response, as text.
    fn www_authenticate(resp: &Response) -> &str {
        resp.headers()
            .get(header::WWW_AUTHENTICATE)
            .unwrap()
            .to_str()
            .unwrap()
    }

    // ---- API key generation and verification -----------------------------

    #[test]
    fn generate_and_verify_api_key() {
        let (token, hash) = generate_api_key().unwrap();

        // 32 random bytes in unpadded base64 are 43 characters.
        assert_eq!(token.len(), 43);
        assert!(hash.starts_with("$argon2id$"));

        let keys = vec![key_entry("test", hash, "viewer", None)];
        let id = verify_bearer_token(&token, &keys);
        assert!(id.is_some());
        let id = id.unwrap();
        assert_eq!(id.name, "test");
        assert_eq!(id.role, "viewer");
        assert_eq!(id.method, AuthMethod::BearerToken);
    }

    #[test]
    fn wrong_token_rejected() {
        let (_token, hash) = generate_api_key().unwrap();
        let keys = vec![key_entry("test", hash, "viewer", None)];
        assert!(verify_bearer_token("wrong-token", &keys).is_none());
    }

    #[test]
    fn expired_key_rejected() {
        let (token, hash) = generate_api_key().unwrap();
        let keys = vec![key_entry("test", hash, "viewer", Some("2020-01-01T00:00:00Z"))];
        assert!(verify_bearer_token(&token, &keys).is_none());
    }

    #[test]
    fn future_expiry_accepted() {
        let (token, hash) = generate_api_key().unwrap();
        let keys = vec![key_entry("test", hash, "viewer", Some("2099-01-01T00:00:00Z"))];
        assert!(verify_bearer_token(&token, &keys).is_some());
    }

    #[test]
    fn multiple_keys_first_match_wins() {
        let (token, hash) = generate_api_key().unwrap();
        let keys = vec![
            key_entry(
                "wrong",
                "$argon2id$v=19$m=19456,t=2,p=1$invalid$invalid".into(),
                "ops",
                None,
            ),
            key_entry("correct", hash, "deploy", None),
        ];
        let id = verify_bearer_token(&token, &keys).unwrap();
        assert_eq!(id.name, "correct");
        assert_eq!(id.role, "deploy");
    }

    // ---- failed-attempt limiter ------------------------------------------

    #[test]
    fn rate_limiter_allows_within_quota() {
        let limiter = build_rate_limiter(&limits(5, None));
        let ip: IpAddr = "10.0.0.1".parse().unwrap();

        for _ in 0..5 {
            assert!(limiter.check_key(&ip).is_ok());
        }
        assert!(limiter.check_key(&ip).is_err());
    }

    #[test]
    fn rate_limiter_separate_ips() {
        let limiter = build_rate_limiter(&limits(2, None));
        let ip1: IpAddr = "10.0.0.1".parse().unwrap();
        let ip2: IpAddr = "10.0.0.2".parse().unwrap();

        // Exhaust the first IP's quota.
        assert!(limiter.check_key(&ip1).is_ok());
        assert!(limiter.check_key(&ip1).is_ok());
        assert!(limiter.check_key(&ip1).is_err());

        // A different IP has its own quota.
        assert!(limiter.check_key(&ip2).is_ok());
    }

    // ---- mTLS identity extraction ----------------------------------------

    #[test]
    fn extract_mtls_identity_from_cn() {
        let mut params = rcgen::CertificateParams::new(vec!["test-client.local".into()]).unwrap();
        params.distinguished_name = rcgen::DistinguishedName::new();
        params
            .distinguished_name
            .push(rcgen::DnType::CommonName, "test-client");
        let key_pair = rcgen::KeyPair::generate().unwrap();
        let cert = params.self_signed(&key_pair).unwrap();

        let id = extract_mtls_identity(cert.der(), "ops").unwrap();
        assert_eq!(id.name, "test-client");
        assert_eq!(id.role, "ops");
        assert_eq!(id.method, AuthMethod::MtlsCertificate);
    }

    #[test]
    fn extract_mtls_identity_falls_back_to_san() {
        let mut params =
            rcgen::CertificateParams::new(vec!["san-only.example.com".into()]).unwrap();
        // No CN in the subject, so the DNS SAN must be used.
        params.distinguished_name = rcgen::DistinguishedName::new();
        let key_pair = rcgen::KeyPair::generate().unwrap();
        let cert = params.self_signed(&key_pair).unwrap();

        let id = extract_mtls_identity(cert.der(), "viewer").unwrap();
        assert_eq!(id.name, "san-only.example.com");
        assert_eq!(id.role, "viewer");
    }

    #[test]
    fn extract_mtls_identity_invalid_der() {
        assert!(extract_mtls_identity(b"not-a-cert", "viewer").is_none());
    }

    // ---- middleware ------------------------------------------------------

    #[tokio::test]
    async fn middleware_rejects_no_credentials() {
        let state = test_auth_state(vec![]);
        let app = auth_router(Arc::clone(&state));

        let resp = app.oneshot(post_mcp(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        assert!(www_authenticate(&resp).contains("error=\"invalid_request\""));

        let counters = state.counters_snapshot();
        assert_eq!(counters.failure_missing_credential, 1);
    }

    #[tokio::test]
    async fn middleware_accepts_valid_bearer() {
        let (token, hash) = generate_api_key().unwrap();
        let state = test_auth_state(vec![key_entry("test-key", hash, "ops", None)]);
        let app = auth_router(Arc::clone(&state));

        let authorization = format!("Bearer {token}");
        let resp = app.oneshot(post_mcp(Some(&authorization))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let counters = state.counters_snapshot();
        assert_eq!(counters.success_bearer, 1);
    }

    #[tokio::test]
    async fn middleware_rejects_wrong_bearer() {
        let (_token, hash) = generate_api_key().unwrap();
        let state = test_auth_state(vec![key_entry("test-key", hash, "ops", None)]);
        let app = auth_router(Arc::clone(&state));

        let resp = app
            .oneshot(post_mcp(Some("Bearer wrong-token-here")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        assert!(www_authenticate(&resp).contains("error=\"invalid_token\""));

        let counters = state.counters_snapshot();
        assert_eq!(counters.failure_invalid_credential, 1);
    }

    #[tokio::test]
    async fn middleware_rate_limits() {
        let state = auth_state_with(vec![], Some(build_rate_limiter(&limits(1, None))), None);
        let app = auth_router(state);

        // No peer address is attached to the request, so only the plain 401
        // path is exercised here.
        let resp = app.oneshot(post_mcp(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn rate_limit_semantics_failed_only() {
        let limiter = build_rate_limiter(&limits(3, None));
        let ip: IpAddr = "192.168.1.100".parse().unwrap();

        for attempt in 1..=3 {
            assert!(
                limiter.check_key(&ip).is_ok(),
                "failure {attempt} should be allowed"
            );
        }
        assert!(
            limiter.check_key(&ip).is_err(),
            "failure 4 should be blocked"
        );
    }

    // ---- pre-auth gate ---------------------------------------------------

    #[test]
    fn pre_auth_default_multiplier_is_10x() {
        let limiter = build_pre_auth_limiter(&limits(5, None));
        let ip: IpAddr = "10.0.0.1".parse().unwrap();

        for i in 0..50 {
            assert!(
                limiter.check_key(&ip).is_ok(),
                "pre-auth attempt {i} (of expected 50) should be allowed under default 10x multiplier"
            );
        }
        assert!(
            limiter.check_key(&ip).is_err(),
            "pre-auth attempt 51 should be blocked (quota is 50, not unbounded)"
        );
    }

    #[test]
    fn pre_auth_explicit_override_wins() {
        let limiter = build_pre_auth_limiter(&limits(100, Some(2)));
        let ip: IpAddr = "10.0.0.2".parse().unwrap();

        assert!(limiter.check_key(&ip).is_ok(), "attempt 1 allowed");
        assert!(limiter.check_key(&ip).is_ok(), "attempt 2 allowed");
        assert!(
            limiter.check_key(&ip).is_err(),
            "attempt 3 must be blocked (explicit override of 2 wins over 10x default of 1000)"
        );
    }

    #[tokio::test]
    async fn pre_auth_gate_blocks_before_argon2_verification() {
        let (_token, hash) = generate_api_key().unwrap();
        let state = auth_state_with(
            vec![key_entry("test-key", hash, "ops", None)],
            None,
            Some(build_pre_auth_limiter(&limits(100, Some(1)))),
        );
        let app = auth_router(Arc::clone(&state));
        let peer: SocketAddr = "10.0.0.10:54321".parse().unwrap();

        let mut first = post_mcp(Some("Bearer obviously-not-a-real-token"));
        first.extensions_mut().insert(ConnectInfo(peer));
        let first_resp = app.clone().oneshot(first).await.unwrap();
        assert_eq!(
            first_resp.status(),
            StatusCode::UNAUTHORIZED,
            "first attempt: gate has quota, falls through to bearer auth which fails with 401"
        );

        let mut second = post_mcp(Some("Bearer also-not-a-real-token"));
        second.extensions_mut().insert(ConnectInfo(peer));
        let second_resp = app.oneshot(second).await.unwrap();
        assert_eq!(
            second_resp.status(),
            StatusCode::TOO_MANY_REQUESTS,
            "second attempt from same IP: pre-auth gate must reject with 429"
        );

        let counters = state.counters_snapshot();
        assert_eq!(
            counters.failure_pre_auth_gate, 1,
            "exactly one request must have been rejected by the pre-auth gate"
        );
        assert_eq!(
            counters.failure_invalid_credential, 1,
            "bearer verification must run exactly once (only the un-gated first request)"
        );
    }

    #[tokio::test]
    async fn pre_auth_gate_does_not_throttle_mtls() {
        let state = auth_state_with(
            vec![],
            None,
            Some(build_pre_auth_limiter(&limits(100, Some(1)))),
        );
        let app = auth_router(Arc::clone(&state));
        let peer: SocketAddr = "10.0.0.20:54321".parse().unwrap();
        let identity = AuthIdentity {
            name: "cn=test-client".into(),
            role: "viewer".into(),
            method: AuthMethod::MtlsCertificate,
            raw_token: None,
            sub: None,
        };
        let tls_info = TlsConnInfo::new(peer, Some(identity));

        for i in 0..3 {
            let mut req = post_mcp(None);
            req.extensions_mut().insert(ConnectInfo(tls_info.clone()));
            let resp = app.clone().oneshot(req).await.unwrap();
            assert_eq!(
                resp.status(),
                StatusCode::OK,
                "mTLS request {i} must succeed: pre-auth gate must not apply to mTLS callers"
            );
        }

        let counters = state.counters_snapshot();
        assert_eq!(
            counters.failure_pre_auth_gate, 0,
            "pre-auth gate counter must remain at zero: mTLS bypasses the gate"
        );
        assert_eq!(
            counters.success_mtls, 3,
            "all three mTLS requests must have been counted as successful"
        );
    }
}