1use std::{
10 collections::HashSet,
11 net::{IpAddr, SocketAddr},
12 num::NonZeroU32,
13 path::PathBuf,
14 sync::{
15 Arc, LazyLock, Mutex,
16 atomic::{AtomicU64, Ordering},
17 },
18 time::Duration,
19};
20
21use arc_swap::ArcSwap;
22use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier, password_hash::SaltString};
23use axum::{
24 body::Body,
25 extract::ConnectInfo,
26 http::{Request, header},
27 middleware::Next,
28 response::{IntoResponse, Response},
29};
30use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
31use secrecy::SecretString;
32use serde::Deserialize;
33use x509_parser::prelude::*;
34
35use crate::{bounded_limiter::BoundedKeyedLimiter, error::McpxError};
36
/// An authenticated caller.
///
/// Built by the bearer-token, OAuth JWT, or mTLS paths in this module and
/// inserted into the request extensions by `auth_middleware` on success.
/// `Debug` is implemented by hand so that `raw_token` and `sub` are redacted.
#[derive(Clone)]
#[non_exhaustive]
pub struct AuthIdentity {
    /// Identity name: the API key entry name (bearer), the certificate CN or
    /// DNS SAN (mTLS), or whatever `crate::oauth` produces for JWTs (confirm there).
    pub name: String,
    /// Role assigned to the caller (for mTLS, the `default_role` passed to
    /// `extract_mtls_identity`).
    pub role: String,
    /// Which authentication mechanism produced this identity.
    pub method: AuthMethod,
    /// The raw bearer token. Set only on the OAuth JWT path
    /// (`authenticate_bearer_identity`); `None` for API keys and mTLS.
    pub raw_token: Option<SecretString>,
    /// Subject identifier; `None` for API-key and mTLS identities. Presumably the
    /// JWT `sub` claim on the OAuth path (TODO confirm in `crate::oauth`).
    pub sub: Option<String>,
}
64
65impl std::fmt::Debug for AuthIdentity {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 f.debug_struct("AuthIdentity")
70 .field("name", &self.name)
71 .field("role", &self.role)
72 .field("method", &self.method)
73 .field(
74 "raw_token",
75 &if self.raw_token.is_some() {
76 "<redacted>"
77 } else {
78 "<none>"
79 },
80 )
81 .field(
82 "sub",
83 &if self.sub.is_some() {
84 "<redacted>"
85 } else {
86 "<none>"
87 },
88 )
89 .finish()
90 }
91}
92
/// The mechanism that authenticated a request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum AuthMethod {
    /// A static API key presented as `Authorization: Bearer <token>`.
    BearerToken,
    /// A client certificate presented during the TLS handshake.
    MtlsCertificate,
    /// A JWT validated against the configured JWKS (`oauth` feature).
    OAuthJwt,
}
104
/// Why an authentication attempt was rejected.
///
/// Drives the failure counters, the log label, the `WWW-Authenticate`
/// challenge, and the plain-text response body.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AuthFailureClass {
    /// No `Authorization` header was sent.
    MissingCredential,
    /// A credential was sent but did not verify (or was not a Bearer token).
    InvalidCredential,
    /// A JWT was recognised but is expired (only produced with the `oauth` feature).
    #[cfg_attr(not(feature = "oauth"), allow(dead_code))]
    ExpiredCredential,
    /// The post-failure rate limiter rejected the source.
    RateLimited,
    /// The pre-auth rate limiter rejected the source before verification.
    PreAuthGate,
}

impl AuthFailureClass {
    /// Stable snake_case label used in logs.
    fn as_str(self) -> &'static str {
        match self {
            Self::PreAuthGate => "pre_auth_gate",
            Self::RateLimited => "rate_limited",
            Self::ExpiredCredential => "expired_credential",
            Self::InvalidCredential => "invalid_credential",
            Self::MissingCredential => "missing_credential",
        }
    }

    /// `(error, error_description)` pair for the Bearer `WWW-Authenticate` challenge.
    fn bearer_error(self) -> (&'static str, &'static str) {
        // Credential problems are `invalid_token`; everything else is `invalid_request`.
        let code = match self {
            Self::InvalidCredential | Self::ExpiredCredential => "invalid_token",
            Self::MissingCredential | Self::RateLimited | Self::PreAuthGate => "invalid_request",
        };
        let description = match self {
            Self::MissingCredential => "missing bearer token or mTLS client certificate",
            Self::InvalidCredential => "token is invalid",
            Self::ExpiredCredential => "token is expired",
            Self::RateLimited => "too many failed authentication attempts",
            Self::PreAuthGate => "too many unauthenticated requests from this source",
        };
        (code, description)
    }

    /// Plain-text body returned alongside the 401 status.
    fn response_body(self) -> &'static str {
        match self {
            Self::PreAuthGate => "rate limited (pre-auth)",
            Self::RateLimited => "rate limited",
            Self::ExpiredCredential => "unauthorized: expired credential",
            Self::InvalidCredential => "unauthorized: invalid credential",
            Self::MissingCredential => "unauthorized: missing credential",
        }
    }
}
155
/// Point-in-time copy of the authentication counters (see `AuthCounters::snapshot`).
///
/// Each field is loaded independently with relaxed ordering, so the snapshot
/// is not guaranteed to be mutually consistent across fields.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
#[non_exhaustive]
pub struct AuthCountersSnapshot {
    /// Successful mTLS authentications.
    pub success_mtls: u64,
    /// Successful API-key bearer authentications.
    pub success_bearer: u64,
    /// Successful OAuth JWT authentications.
    pub success_oauth_jwt: u64,
    /// Requests rejected for carrying no credential.
    pub failure_missing_credential: u64,
    /// Requests rejected for an invalid credential.
    pub failure_invalid_credential: u64,
    /// Requests rejected for an expired credential.
    pub failure_expired_credential: u64,
    /// Requests rejected by the post-failure rate limiter.
    pub failure_rate_limited: u64,
    /// Requests rejected by the pre-auth gate.
    pub failure_pre_auth_gate: u64,
}
178
/// Lock-free success/failure counters, one atomic per outcome.
///
/// Updated with `Ordering::Relaxed`; read via `snapshot`.
#[derive(Debug, Default)]
pub(crate) struct AuthCounters {
    success_mtls: AtomicU64,
    success_bearer: AtomicU64,
    success_oauth_jwt: AtomicU64,
    failure_missing_credential: AtomicU64,
    failure_invalid_credential: AtomicU64,
    failure_expired_credential: AtomicU64,
    failure_rate_limited: AtomicU64,
    failure_pre_auth_gate: AtomicU64,
}
191
192impl AuthCounters {
193 fn record_success(&self, method: AuthMethod) {
194 match method {
195 AuthMethod::MtlsCertificate => {
196 self.success_mtls.fetch_add(1, Ordering::Relaxed);
197 }
198 AuthMethod::BearerToken => {
199 self.success_bearer.fetch_add(1, Ordering::Relaxed);
200 }
201 AuthMethod::OAuthJwt => {
202 self.success_oauth_jwt.fetch_add(1, Ordering::Relaxed);
203 }
204 }
205 }
206
207 fn record_failure(&self, class: AuthFailureClass) {
208 match class {
209 AuthFailureClass::MissingCredential => {
210 self.failure_missing_credential
211 .fetch_add(1, Ordering::Relaxed);
212 }
213 AuthFailureClass::InvalidCredential => {
214 self.failure_invalid_credential
215 .fetch_add(1, Ordering::Relaxed);
216 }
217 AuthFailureClass::ExpiredCredential => {
218 self.failure_expired_credential
219 .fetch_add(1, Ordering::Relaxed);
220 }
221 AuthFailureClass::RateLimited => {
222 self.failure_rate_limited.fetch_add(1, Ordering::Relaxed);
223 }
224 AuthFailureClass::PreAuthGate => {
225 self.failure_pre_auth_gate.fetch_add(1, Ordering::Relaxed);
226 }
227 }
228 }
229
230 fn snapshot(&self) -> AuthCountersSnapshot {
231 AuthCountersSnapshot {
232 success_mtls: self.success_mtls.load(Ordering::Relaxed),
233 success_bearer: self.success_bearer.load(Ordering::Relaxed),
234 success_oauth_jwt: self.success_oauth_jwt.load(Ordering::Relaxed),
235 failure_missing_credential: self.failure_missing_credential.load(Ordering::Relaxed),
236 failure_invalid_credential: self.failure_invalid_credential.load(Ordering::Relaxed),
237 failure_expired_credential: self.failure_expired_credential.load(Ordering::Relaxed),
238 failure_rate_limited: self.failure_rate_limited.load(Ordering::Relaxed),
239 failure_pre_auth_gate: self.failure_pre_auth_gate.load(Ordering::Relaxed),
240 }
241 }
242}
243
/// A validated RFC 3339 timestamp.
///
/// Wraps `chrono::DateTime<FixedOffset>`. Comparison and hashing are derived,
/// so they delegate to chrono's impls. It (de)serializes as an RFC 3339 string,
/// and deserialization rejects malformed input.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[non_exhaustive]
pub struct RfcTimestamp(chrono::DateTime<chrono::FixedOffset>);
258
259impl RfcTimestamp {
260 pub fn parse(s: &str) -> Result<Self, chrono::ParseError> {
268 chrono::DateTime::parse_from_rfc3339(s).map(Self)
269 }
270
271 #[must_use]
273 pub fn as_datetime(&self) -> &chrono::DateTime<chrono::FixedOffset> {
274 &self.0
275 }
276
277 #[must_use]
279 pub fn into_inner(self) -> chrono::DateTime<chrono::FixedOffset> {
280 self.0
281 }
282}
283
284impl std::fmt::Display for RfcTimestamp {
285 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286 write!(f, "{}", self.0.to_rfc3339())
288 }
289}
290
291impl std::fmt::Debug for RfcTimestamp {
292 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
293 write!(f, "{}", self.0.to_rfc3339())
298 }
299}
300
301impl<'de> Deserialize<'de> for RfcTimestamp {
302 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
303 where
304 D: serde::Deserializer<'de>,
305 {
306 let s = String::deserialize(deserializer)?;
310 Self::parse(&s).map_err(serde::de::Error::custom)
311 }
312}
313
314impl serde::Serialize for RfcTimestamp {
315 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
316 where
317 S: serde::Serializer,
318 {
319 serializer.serialize_str(&self.0.to_rfc3339())
320 }
321}
322
323impl From<chrono::DateTime<chrono::FixedOffset>> for RfcTimestamp {
324 fn from(value: chrono::DateTime<chrono::FixedOffset>) -> Self {
325 Self(value)
326 }
327}
328
/// One configured API key.
///
/// `Debug` is implemented by hand so that `hash` is redacted.
#[derive(Clone, Deserialize)]
#[non_exhaustive]
pub struct ApiKeyEntry {
    /// Human-readable key name; becomes `AuthIdentity::name` on a match.
    pub name: String,
    /// PHC-format password hash of the token (as produced by `generate_api_key`,
    /// e.g. `$argon2id$...`). An unparseable value never matches.
    pub hash: String,
    /// Role granted to callers presenting this key.
    pub role: String,
    /// Optional expiry; once past, `verify_bearer_token` treats the key as non-matching.
    pub expires_at: Option<RfcTimestamp>,
}
350
351impl std::fmt::Debug for ApiKeyEntry {
352 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355 f.debug_struct("ApiKeyEntry")
356 .field("name", &self.name)
357 .field("hash", &"<redacted>")
358 .field("role", &self.role)
359 .field("expires_at", &self.expires_at)
360 .finish()
361 }
362}
363
364impl ApiKeyEntry {
365 #[must_use]
367 pub fn new(name: impl Into<String>, hash: impl Into<String>, role: impl Into<String>) -> Self {
368 Self {
369 name: name.into(),
370 hash: hash.into(),
371 role: role.into(),
372 expires_at: None,
373 }
374 }
375
376 #[must_use]
381 pub fn with_expiry(mut self, expires_at: RfcTimestamp) -> Self {
382 self.expires_at = Some(expires_at);
383 self
384 }
385
386 pub fn try_with_expiry(
394 mut self,
395 expires_at: impl AsRef<str>,
396 ) -> Result<Self, chrono::ParseError> {
397 self.expires_at = Some(RfcTimestamp::parse(expires_at.as_ref())?);
398 Ok(self)
399 }
400}
401
/// mTLS (client certificate) authentication settings.
///
/// Duration fields are parsed with `humantime_serde` (e.g. `"30s"`, `"24h"`).
/// The exact CRL semantics live in the CRL/TLS code, not here; the per-field
/// notes below marked "presumably" should be confirmed there.
#[derive(Debug, Clone, Deserialize)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "mTLS CRL behavior is intentionally configured as independent booleans"
)]
#[non_exhaustive]
pub struct MtlsConfig {
    /// Path to the CA certificate(s) used to verify client certificates.
    pub ca_cert_path: PathBuf,
    /// Defaults to `false`. Presumably, when `true`, clients without a
    /// certificate are refused at the TLS layer (confirm).
    #[serde(default)]
    pub required: bool,
    /// Role assigned to mTLS identities; defaults to `"viewer"`.
    #[serde(default = "default_mtls_role")]
    pub default_role: String,
    /// Whether CRL checking is enabled; defaults to `true`.
    #[serde(default = "default_true")]
    pub crl_enabled: bool,
    /// Optional CRL refresh interval; `None` by default (the refresh policy is then
    /// decided by the CRL code; confirm).
    #[serde(default, with = "humantime_serde::option")]
    pub crl_refresh_interval: Option<Duration>,
    /// Per-fetch CRL timeout; defaults to 30 seconds.
    #[serde(default = "default_crl_fetch_timeout", with = "humantime_serde")]
    pub crl_fetch_timeout: Duration,
    /// How long a stale CRL may still be used; defaults to 24 hours.
    #[serde(default = "default_crl_stale_grace", with = "humantime_serde")]
    pub crl_stale_grace: Duration,
    /// Defaults to `false`. Presumably, when `true`, certificates are denied if
    /// no usable CRL is available (confirm).
    #[serde(default)]
    pub crl_deny_on_unavailable: bool,
    /// Defaults to `false`. Presumably restricts revocation checks to the leaf
    /// certificate (confirm).
    #[serde(default)]
    pub crl_end_entity_only: bool,
    /// Defaults to `true`. Presumably permits fetching CRLs over plain `http://` (confirm).
    #[serde(default = "default_true")]
    pub crl_allow_http: bool,
    /// Defaults to `true`. Presumably rejects CRLs past their next-update time (confirm).
    #[serde(default = "default_true")]
    pub crl_enforce_expiration: bool,
    /// Maximum concurrent CRL fetches; defaults to 4.
    #[serde(default = "default_crl_max_concurrent_fetches")]
    pub crl_max_concurrent_fetches: usize,
    /// Maximum CRL response size in bytes; defaults to 5 MiB.
    #[serde(default = "default_crl_max_response_bytes")]
    pub crl_max_response_bytes: u64,
    /// Rate limit for discovering new CRL URLs, per minute; defaults to 60.
    #[serde(default = "default_crl_discovery_rate_per_min")]
    pub crl_discovery_rate_per_min: u32,
    /// Cap on per-host semaphores; defaults to 1024.
    #[serde(default = "default_crl_max_host_semaphores")]
    pub crl_max_host_semaphores: usize,
    /// Cap on remembered CRL URLs; defaults to 4096.
    #[serde(default = "default_crl_max_seen_urls")]
    pub crl_max_seen_urls: usize,
    /// Cap on cached CRL entries; defaults to 1024.
    #[serde(default = "default_crl_max_cache_entries")]
    pub crl_max_cache_entries: usize,
}
505
/// Serde default for `MtlsConfig::default_role`.
fn default_mtls_role() -> String {
    String::from("viewer")
}

/// Serde default for boolean fields that are on unless configured otherwise.
const fn default_true() -> bool {
    true
}

/// Serde default for `MtlsConfig::crl_fetch_timeout` (30 s).
const fn default_crl_fetch_timeout() -> Duration {
    Duration::from_secs(30)
}

/// Serde default for `MtlsConfig::crl_stale_grace` (24 h).
const fn default_crl_stale_grace() -> Duration {
    Duration::from_secs(24 * 60 * 60)
}

/// Serde default for `MtlsConfig::crl_max_concurrent_fetches`.
const fn default_crl_max_concurrent_fetches() -> usize {
    4
}

/// Serde default for `MtlsConfig::crl_max_response_bytes` (5 MiB).
const fn default_crl_max_response_bytes() -> u64 {
    5 << 20
}

/// Serde default for `MtlsConfig::crl_discovery_rate_per_min`.
const fn default_crl_discovery_rate_per_min() -> u32 {
    60
}

/// Serde default for `MtlsConfig::crl_max_host_semaphores`.
const fn default_crl_max_host_semaphores() -> usize {
    1024
}

/// Serde default for `MtlsConfig::crl_max_seen_urls`.
const fn default_crl_max_seen_urls() -> usize {
    4096
}

/// Serde default for `MtlsConfig::crl_max_cache_entries`.
const fn default_crl_max_cache_entries() -> usize {
    1024
}
545
/// Per-source-IP authentication rate limiting.
///
/// Two limiters are built from this: a post-failure limiter
/// (`build_rate_limiter`) and a pre-auth gate (`build_pre_auth_limiter`).
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct RateLimitConfig {
    /// Failed attempts allowed per minute per IP; defaults to 30. A value of 0
    /// falls back to 30 when the limiter is built.
    #[serde(default = "default_max_attempts")]
    pub max_attempts_per_minute: u32,
    /// Pre-auth gate quota per minute per IP. When `None`, the quota is
    /// `max_attempts_per_minute * 10` (saturating); a resolved value of 0 falls
    /// back to 300.
    #[serde(default)]
    pub pre_auth_max_per_minute: Option<u32>,
    /// Upper bound on tracked IP keys per limiter; defaults to 10 000.
    #[serde(default = "default_max_tracked_keys")]
    pub max_tracked_keys: usize,
    /// Idle time after which a tracked key may be evicted; defaults to 15 minutes.
    #[serde(default = "default_idle_eviction", with = "humantime_serde")]
    pub idle_eviction: Duration,
    /// Optional burst size for the post-failure limiter; `None` or `0` keeps the
    /// quota's default burst.
    #[serde(default)]
    pub burst: Option<u32>,
    /// Optional burst size for the pre-auth gate; `None` or `0` keeps the
    /// quota's default burst.
    #[serde(default)]
    pub pre_auth_burst: Option<u32>,
}
602
603impl Default for RateLimitConfig {
604 fn default() -> Self {
605 Self {
606 max_attempts_per_minute: default_max_attempts(),
607 pre_auth_max_per_minute: None,
608 max_tracked_keys: default_max_tracked_keys(),
609 idle_eviction: default_idle_eviction(),
610 burst: None,
611 pre_auth_burst: None,
612 }
613 }
614}
615
616impl RateLimitConfig {
617 #[must_use]
621 pub fn new(max_attempts_per_minute: u32) -> Self {
622 Self {
623 max_attempts_per_minute,
624 ..Self::default()
625 }
626 }
627
628 #[must_use]
631 pub fn with_pre_auth_max_per_minute(mut self, quota: u32) -> Self {
632 self.pre_auth_max_per_minute = Some(quota);
633 self
634 }
635
636 #[must_use]
638 pub fn with_max_tracked_keys(mut self, max: usize) -> Self {
639 self.max_tracked_keys = max;
640 self
641 }
642
643 #[must_use]
645 pub fn with_idle_eviction(mut self, idle: Duration) -> Self {
646 self.idle_eviction = idle;
647 self
648 }
649
650 #[must_use]
653 pub fn with_burst(mut self, burst: u32) -> Self {
654 self.burst = Some(burst);
655 self
656 }
657
658 #[must_use]
661 pub fn with_pre_auth_burst(mut self, burst: u32) -> Self {
662 self.pre_auth_burst = Some(burst);
663 self
664 }
665}
666
// These are `const fn`, consistent with the mTLS serde default helpers, so they
// are usable in const contexts too.

/// Serde/`Default` value for `RateLimitConfig::max_attempts_per_minute`.
const fn default_max_attempts() -> u32 {
    30
}

/// Serde/`Default` value for `RateLimitConfig::max_tracked_keys`.
const fn default_max_tracked_keys() -> usize {
    10_000
}

/// Serde/`Default` value for `RateLimitConfig::idle_eviction` (15 minutes).
const fn default_idle_eviction() -> Duration {
    Duration::from_secs(15 * 60)
}
678
/// Top-level authentication configuration.
///
/// `Default` yields auth disabled, with no keys, no mTLS, no rate limiting, and
/// (with the `oauth` feature) no OAuth.
#[derive(Debug, Clone, Default, Deserialize)]
#[non_exhaustive]
pub struct AuthConfig {
    /// Master switch; defaults to `false` when omitted.
    #[serde(default)]
    pub enabled: bool,
    /// Static API keys accepted as Bearer tokens; empty when omitted.
    #[serde(default)]
    pub api_keys: Vec<ApiKeyEntry>,
    /// Optional mTLS settings.
    pub mtls: Option<MtlsConfig>,
    /// Optional per-IP rate limiting.
    pub rate_limit: Option<RateLimitConfig>,
    /// Optional OAuth/JWT settings (only with the `oauth` feature).
    #[cfg(feature = "oauth")]
    pub oauth: Option<crate::oauth::OAuthConfig>,
}
697
698impl AuthConfig {
699 #[must_use]
701 pub fn with_keys(keys: Vec<ApiKeyEntry>) -> Self {
702 Self {
703 enabled: true,
704 api_keys: keys,
705 mtls: None,
706 rate_limit: None,
707 #[cfg(feature = "oauth")]
708 oauth: None,
709 }
710 }
711
712 #[must_use]
714 pub fn with_rate_limit(mut self, rate_limit: RateLimitConfig) -> Self {
715 self.rate_limit = Some(rate_limit);
716 self
717 }
718}
719
/// Non-secret view of an API key entry (the hash is deliberately omitted).
#[derive(Debug, Clone, serde::Serialize)]
#[non_exhaustive]
pub struct ApiKeySummary {
    /// Key name.
    pub name: String,
    /// Role granted by the key.
    pub role: String,
    /// Optional expiry.
    pub expires_at: Option<RfcTimestamp>,
}
734
/// Non-secret overview of the auth configuration (see `AuthConfig::summary`).
#[derive(Debug, Clone, serde::Serialize)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "this is a flat summary of independent auth-method booleans"
)]
#[non_exhaustive]
pub struct AuthConfigSummary {
    /// Mirrors `AuthConfig::enabled`.
    pub enabled: bool,
    /// `true` when at least one API key is configured.
    pub bearer: bool,
    /// `true` when an mTLS section is configured.
    pub mtls: bool,
    /// `true` when OAuth is configured (always `false` without the `oauth` feature).
    pub oauth: bool,
    /// Per-key summaries, without hashes.
    pub api_keys: Vec<ApiKeySummary>,
}
754
755impl AuthConfig {
756 #[must_use]
758 pub fn summary(&self) -> AuthConfigSummary {
759 AuthConfigSummary {
760 enabled: self.enabled,
761 bearer: !self.api_keys.is_empty(),
762 mtls: self.mtls.is_some(),
763 #[cfg(feature = "oauth")]
764 oauth: self.oauth.is_some(),
765 #[cfg(not(feature = "oauth"))]
766 oauth: false,
767 api_keys: self
768 .api_keys
769 .iter()
770 .map(|k| ApiKeySummary {
771 name: k.name.clone(),
772 role: k.role.clone(),
773 expires_at: k.expires_at,
774 })
775 .collect(),
776 }
777 }
778}
779
/// Rate limiter keyed by client IP address (shared by the pre-auth gate and
/// the post-failure limiter).
pub(crate) type KeyedLimiter = BoundedKeyedLimiter<IpAddr>;
783
/// Per-connection info made available to handlers via `ConnectInfo<TlsConnInfo>`.
///
/// `auth_middleware` reads it to short-circuit authentication when the TLS layer
/// already produced an mTLS identity.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub(crate) struct TlsConnInfo {
    /// Peer socket address.
    pub addr: SocketAddr,
    /// mTLS identity from the client certificate, if any.
    pub identity: Option<AuthIdentity>,
}
802
803impl TlsConnInfo {
804 #[must_use]
806 pub(crate) const fn new(addr: SocketAddr, identity: Option<AuthIdentity>) -> Self {
807 Self { addr, identity }
808 }
809}
810
/// Capacity used by `SeenIdentitySet::new` / `Default`.
const DEFAULT_SEEN_IDENTITY_CAP: usize = 4096;

/// Bounded, thread-safe set of identity names already seen.
///
/// `AuthState::log_auth` uses it so that an identity's first authentication is
/// logged at `info` and later ones at `debug`. When full, the earliest-inserted
/// name is evicted (insertion order; a repeat lookup does not refresh a name).
pub(crate) struct SeenIdentitySet {
    inner: Mutex<SeenInner>,
}

/// Mutex-guarded state: `set` answers membership, `order` records insertion
/// order for eviction. Both hold the same names.
struct SeenInner {
    set: HashSet<String>,
    order: std::collections::VecDeque<String>,
    cap: usize,
}

impl SeenIdentitySet {
    /// Empty set with the default capacity.
    #[must_use]
    pub(crate) fn new() -> Self {
        Self::with_cap(DEFAULT_SEEN_IDENTITY_CAP)
    }

    /// Empty set holding at most `cap` names (a `cap` of 0 is treated as 1).
    #[must_use]
    pub(crate) fn with_cap(cap: usize) -> Self {
        let cap = cap.max(1);
        // Modest up-front allocation; large caps grow on demand.
        let initial = cap.min(64);
        let state = SeenInner {
            set: HashSet::with_capacity(initial),
            order: std::collections::VecDeque::with_capacity(initial),
            cap,
        };
        Self {
            inner: Mutex::new(state),
        }
    }

    /// Records `name`; returns `true` only if it was not already present.
    ///
    /// A poisoned lock is not treated as an error: the guarded data is recovered
    /// and used as-is.
    pub(crate) fn insert_is_first(&self, name: &str) -> bool {
        let mut state = self
            .inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        if state.set.contains(name) {
            return false;
        }

        // Make room by dropping the oldest name once the cap is reached.
        let evicted = if state.set.len() >= state.cap {
            state.order.pop_front()
        } else {
            None
        };
        if let Some(oldest) = evicted {
            state.set.remove(&oldest);
        }

        let owned = String::from(name);
        state.order.push_back(owned.clone());
        state.set.insert(owned);
        true
    }

    /// Number of names currently tracked (test helper).
    #[cfg(test)]
    pub(crate) fn len(&self) -> usize {
        let state = self
            .inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        state.set.len()
    }
}

impl Default for SeenIdentitySet {
    fn default() -> Self {
        Self::new()
    }
}
923
/// Shared runtime state for `auth_middleware`.
#[allow(
    missing_debug_implementations,
    reason = "contains governor RateLimiter and JwksCache without Debug impls"
)]
#[non_exhaustive]
pub(crate) struct AuthState {
    /// Current API keys; replaced atomically by `reload_keys`.
    pub api_keys: ArcSwap<Vec<ApiKeyEntry>>,
    /// Post-failure limiter, consulted only after an authentication failure.
    pub rate_limiter: Option<Arc<KeyedLimiter>>,
    /// Pre-auth gate, consulted before any credential is verified.
    pub pre_auth_limiter: Option<Arc<KeyedLimiter>>,
    /// JWKS cache for OAuth JWT validation; when `Some`, 401 challenges also
    /// advertise the protected-resource metadata URL.
    #[cfg(feature = "oauth")]
    pub jwks_cache: Option<Arc<crate::oauth::JwksCache>>,
    /// Names already logged at `info` on first authentication.
    pub seen_identities: SeenIdentitySet,
    /// Success/failure counters.
    pub counters: AuthCounters,
}
952
953impl AuthState {
954 pub(crate) fn reload_keys(&self, keys: Vec<ApiKeyEntry>) {
960 let count = keys.len();
961 self.api_keys.store(Arc::new(keys));
962 tracing::info!(keys = count, "API keys reloaded");
963 }
964
965 #[must_use]
967 pub(crate) fn counters_snapshot(&self) -> AuthCountersSnapshot {
968 self.counters.snapshot()
969 }
970
971 #[must_use]
973 pub(crate) fn api_key_summaries(&self) -> Vec<ApiKeySummary> {
974 self.api_keys
975 .load()
976 .iter()
977 .map(|k| ApiKeySummary {
978 name: k.name.clone(),
979 role: k.role.clone(),
980 expires_at: k.expires_at,
981 })
982 .collect()
983 }
984
985 fn log_auth(&self, id: &AuthIdentity, method: &str) {
993 self.counters.record_success(id.method);
994 let first = self.seen_identities.insert_is_first(&id.name);
995 if first {
996 tracing::info!(name = %id.name, role = %id.role, "{method} authenticated");
997 } else {
998 tracing::debug!(name = %id.name, role = %id.role, "{method} authenticated");
999 }
1000 }
1001}
1002
/// Fallback post-failure quota (per minute) used when `max_attempts_per_minute` is 0.
const DEFAULT_AUTH_RATE: NonZeroU32 = NonZeroU32::new(30).unwrap();
1006
1007fn apply_burst(quota: governor::Quota, burst: Option<u32>) -> governor::Quota {
1011 match burst.and_then(NonZeroU32::new) {
1012 Some(b) => quota.allow_burst(b),
1013 None => quota,
1014 }
1015}
1016
1017#[must_use]
1019pub(crate) fn build_rate_limiter(config: &RateLimitConfig) -> Arc<KeyedLimiter> {
1020 let quota = governor::Quota::per_minute(
1021 NonZeroU32::new(config.max_attempts_per_minute).unwrap_or(DEFAULT_AUTH_RATE),
1022 );
1023 let quota = apply_burst(quota, config.burst);
1024 Arc::new(BoundedKeyedLimiter::new(
1025 quota,
1026 config.max_tracked_keys,
1027 config.idle_eviction,
1028 ))
1029}
1030
1031#[must_use]
1038pub(crate) fn build_pre_auth_limiter(config: &RateLimitConfig) -> Arc<KeyedLimiter> {
1039 let resolved = config.pre_auth_max_per_minute.unwrap_or_else(|| {
1040 config
1041 .max_attempts_per_minute
1042 .saturating_mul(PRE_AUTH_DEFAULT_MULTIPLIER)
1043 });
1044 let quota =
1045 governor::Quota::per_minute(NonZeroU32::new(resolved).unwrap_or(DEFAULT_PRE_AUTH_RATE));
1046 let quota = apply_burst(quota, config.pre_auth_burst);
1047 Arc::new(BoundedKeyedLimiter::new(
1048 quota,
1049 config.max_tracked_keys,
1050 config.idle_eviction,
1051 ))
1052}
1053
/// When no explicit pre-auth quota is configured, the pre-auth gate allows this
/// many times the post-failure quota (see `build_pre_auth_limiter`).
const PRE_AUTH_DEFAULT_MULTIPLIER: u32 = 10;
1057
/// Fallback pre-auth quota (per minute) used when the resolved pre-auth quota is 0.
const DEFAULT_PRE_AUTH_RATE: NonZeroU32 = NonZeroU32::new(300).unwrap();
1062
1063#[must_use]
1068pub fn extract_mtls_identity(cert_der: &[u8], default_role: &str) -> Option<AuthIdentity> {
1069 let (_, cert) = X509Certificate::from_der(cert_der).ok()?;
1070
1071 let cn = cert
1073 .subject()
1074 .iter_common_name()
1075 .next()
1076 .and_then(|attr| attr.as_str().ok())
1077 .map(String::from);
1078
1079 let name = cn.or_else(|| {
1081 cert.subject_alternative_name()
1082 .ok()
1083 .flatten()
1084 .and_then(|san| {
1085 #[allow(
1086 clippy::wildcard_enum_match_arm,
1087 reason = "x509-parser GeneralName is a large external enum; only DNSName is meaningful here"
1088 )]
1089 san.value.general_names.iter().find_map(|gn| match gn {
1090 GeneralName::DNSName(dns) => Some((*dns).to_owned()),
1091 _ => None,
1092 })
1093 })
1094 })?;
1095
1096 if !name
1098 .chars()
1099 .all(|c| c.is_alphanumeric() || matches!(c, '-' | '.' | '_' | '@'))
1100 {
1101 tracing::warn!(cn = %name, "mTLS identity rejected: invalid characters in CN/SAN");
1102 return None;
1103 }
1104
1105 Some(AuthIdentity {
1106 name,
1107 role: default_role.to_owned(),
1108 method: AuthMethod::MtlsCertificate,
1109 raw_token: None,
1110 sub: None,
1111 })
1112}
1113
/// Extracts the token from an `Authorization` header value of the form
/// `Bearer <token>`.
///
/// The scheme is matched case-insensitively. Any number of spaces may separate
/// it from the token, and trailing characters are kept verbatim. Returns `None`
/// for other schemes, a missing separator, or an empty token.
fn extract_bearer(value: &str) -> Option<&str> {
    let (scheme, rest) = value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    Some(rest.trim_start_matches(' ')).filter(|token| !token.is_empty())
}
1137
/// Verifies `token` against the configured API keys and returns the identity of
/// the first live key whose hash matches.
///
/// Every entry costs one Argon2 verification, whether or not it can match, so the
/// total work does not reveal which slot matched. This assumes entries use the
/// same Argon2 parameters as `DUMMY_PHC_HASH`. In detail:
/// - expired entries, entries whose `hash` is not a valid PHC string, and every
///   entry after the first match are verified against `DUMMY_PHC_HASH`;
/// - such dummy verifications can never select an entry.
///
/// Returns `None` when no entry matches. This is CPU-heavy;
/// `authenticate_bearer_identity` runs it on `tokio::task::spawn_blocking`.
#[must_use]
pub fn verify_bearer_token(token: &str, keys: &[ApiKeyEntry]) -> Option<AuthIdentity> {
    use subtle::ConstantTimeEq as _;

    let now = chrono::Utc::now();
    #[allow(
        clippy::expect_used,
        reason = "DUMMY_PHC_HASH is a static LazyLock built from a fixed Argon2id PHC string by construction; PasswordHash::new on it is infallible. See DUMMY_PHC_HASH definition."
    )]
    let dummy_hash = PasswordHash::new(&DUMMY_PHC_HASH)
        .expect("DUMMY_PHC_HASH is a valid Argon2id PHC string by construction");

    // Placeholder; only read below when `any_match == 1`, at which point it has
    // been set to the first matching slot.
    let mut matched_index: usize = usize::MAX;
    // 0/1 flag kept as `u8` so it can be combined with the bitwise masks below.
    let mut any_match: u8 = 0;

    for (idx, key) in keys.iter().enumerate() {
        let expired = key.expires_at.is_some_and(|exp| exp.as_datetime() < &now);

        let real_hash = PasswordHash::new(&key.hash);
        // Only the first live, well-formed slot before any match is checked against
        // its real hash; everything else burns the same work on the dummy hash.
        let verify_against = match (&real_hash, expired, any_match) {
            (Ok(h), false, 0) => h,
            _ => &dummy_hash,
        };

        let slot_ok = u8::from(
            Argon2::default()
                .verify_password(token.as_bytes(), verify_against)
                .is_ok(),
        );

        // Mask out expired and malformed slots so a dummy-hash success on them
        // never counts as a match.
        let real_match = slot_ok & u8::from(!expired) & u8::from(real_hash.is_ok());
        // Only the first real match selects an index; later ones are ignored.
        let first_real_match = real_match & (1 - any_match);
        if first_real_match.ct_eq(&1).into() {
            matched_index = idx;
        }
        any_match |= real_match;
    }

    if any_match == 0 {
        return None;
    }
    // Defensive: `get` rather than indexing, so an unexpected index yields `None`.
    let key = keys.get(matched_index)?;
    Some(AuthIdentity {
        name: key.name.clone(),
        role: key.role.clone(),
        method: AuthMethod::BearerToken,
        raw_token: None,
        sub: None,
    })
}
1216
/// Stand-in Argon2 PHC hash used by `verify_bearer_token`.
///
/// Expired entries, entries with an unparseable `hash`, and entries after the
/// first match are verified against this instead of their own hash. Every entry
/// therefore still costs an Argon2 verification.
///
/// Computed once, on first use, with `Argon2::default()` (the same parameters
/// `generate_api_key` uses) over a fixed plaintext and a fixed salt.
static DUMMY_PHC_HASH: LazyLock<String> = LazyLock::new(|| {
    #[allow(
        clippy::expect_used,
        reason = "fixed 22-char base64 ('AAAA...') decodes to a valid 16-byte salt; SaltString::from_b64 is infallible on this literal"
    )]
    let salt = SaltString::from_b64("AAAAAAAAAAAAAAAAAAAAAA")
        .expect("fixed 16-byte base64 salt is well-formed");
    #[allow(
        clippy::expect_used,
        reason = "Argon2::default() with a fixed plaintext and a well-formed salt is infallible; only fails on bad params/salt"
    )]
    Argon2::default()
        .hash_password(b"rmcp-server-kit-dummy", &salt)
        .expect("Argon2 default params hash a fixed plaintext")
        .to_string()
});
1246
1247pub fn generate_api_key() -> Result<(String, String), McpxError> {
1257 let mut token_bytes = [0u8; 32];
1258 rand::fill(&mut token_bytes);
1259 let token = URL_SAFE_NO_PAD.encode(token_bytes);
1260
1261 let mut salt_bytes = [0u8; 16];
1263 rand::fill(&mut salt_bytes);
1264 let salt = SaltString::encode_b64(&salt_bytes)
1265 .map_err(|e| McpxError::Auth(format!("salt encoding failed: {e}")))?;
1266 let hash = Argon2::default()
1267 .hash_password(token.as_bytes(), &salt)
1268 .map_err(|e| McpxError::Auth(format!("argon2id hashing failed: {e}")))?
1269 .to_string();
1270
1271 Ok((token, hash))
1272}
1273
1274fn build_www_authenticate_value(
1275 advertise_resource_metadata: bool,
1276 failure: AuthFailureClass,
1277) -> String {
1278 let (error, error_description) = failure.bearer_error();
1279 if advertise_resource_metadata {
1280 return format!(
1281 "Bearer resource_metadata=\"/.well-known/oauth-protected-resource\", error=\"{error}\", error_description=\"{error_description}\""
1282 );
1283 }
1284 format!("Bearer error=\"{error}\", error_description=\"{error_description}\"")
1285}
1286
1287fn auth_method_label(method: AuthMethod) -> &'static str {
1288 match method {
1289 AuthMethod::MtlsCertificate => "mTLS",
1290 AuthMethod::BearerToken => "bearer token",
1291 AuthMethod::OAuthJwt => "OAuth JWT",
1292 }
1293}
1294
1295#[cfg_attr(not(feature = "oauth"), allow(unused_variables))]
1296fn unauthorized_response(state: &AuthState, failure_class: AuthFailureClass) -> Response {
1297 #[cfg(feature = "oauth")]
1298 let advertise_resource_metadata = state.jwks_cache.is_some();
1299 #[cfg(not(feature = "oauth"))]
1300 let advertise_resource_metadata = false;
1301
1302 let challenge = build_www_authenticate_value(advertise_resource_metadata, failure_class);
1303 (
1304 axum::http::StatusCode::UNAUTHORIZED,
1305 [(header::WWW_AUTHENTICATE, challenge)],
1306 failure_class.response_body(),
1307 )
1308 .into_response()
1309}
1310
1311async fn authenticate_bearer_identity(
1312 state: &AuthState,
1313 token: &str,
1314) -> Result<AuthIdentity, AuthFailureClass> {
1315 let mut failure_class = AuthFailureClass::MissingCredential;
1316
1317 #[cfg(feature = "oauth")]
1318 if let Some(ref cache) = state.jwks_cache
1319 && crate::oauth::looks_like_jwt(token)
1320 {
1321 match cache.validate_token_with_reason(token).await {
1322 Ok(mut id) => {
1323 id.raw_token = Some(SecretString::from(token.to_owned()));
1324 return Ok(id);
1325 }
1326 Err(crate::oauth::JwtValidationFailure::Expired) => {
1327 failure_class = AuthFailureClass::ExpiredCredential;
1328 }
1329 Err(crate::oauth::JwtValidationFailure::Invalid) => {
1330 failure_class = AuthFailureClass::InvalidCredential;
1331 }
1332 }
1333 }
1334
1335 let token = token.to_owned();
1336 let keys = state.api_keys.load_full(); let identity = tokio::task::spawn_blocking(move || verify_bearer_token(&token, &keys))
1340 .await
1341 .ok()
1342 .flatten();
1343
1344 if let Some(id) = identity {
1345 return Ok(id);
1346 }
1347
1348 if failure_class == AuthFailureClass::MissingCredential {
1349 failure_class = AuthFailureClass::InvalidCredential;
1350 }
1351
1352 Err(failure_class)
1353}
1354
1355fn pre_auth_gate(state: &AuthState, client_ip: Option<IpAddr>) -> Option<Response> {
1366 let limiter = state.pre_auth_limiter.as_ref()?;
1367 let ip = client_ip?;
1368 let Err(wait) = limiter.check_key_wait(&ip) else {
1369 return None;
1370 };
1371 state.counters.record_failure(AuthFailureClass::PreAuthGate);
1372 tracing::warn!(
1373 %ip,
1374 "auth rate limited by pre-auth gate (request rejected before credential verification)"
1375 );
1376 Some(
1377 McpxError::RateLimitedFor {
1378 message: "too many unauthenticated requests from this source".into(),
1379 retry_after: wait,
1380 }
1381 .into_response(),
1382 )
1383}
1384
/// Axum middleware that authenticates every request before passing it on.
///
/// Checks run in this order:
/// 1. An mTLS identity attached to the connection (`ConnectInfo<TlsConnInfo>`)
///    is accepted immediately, without consulting either rate limiter.
/// 2. The pre-auth gate may reject the source before any credential is verified.
/// 3. A `Bearer` `Authorization` header is verified (OAuth JWT and/or API keys).
/// 4. On failure, the post-failure limiter is consulted. It is only checked on
///    failures, so its quota counts failed attempts. When it is exhausted the
///    response is a rate-limit error instead of `401`.
///
/// On success the `AuthIdentity` is inserted into the request extensions.
pub(crate) async fn auth_middleware(
    state: Arc<AuthState>,
    req: Request<Body>,
    next: Next,
) -> Response {
    let tls_info = req.extensions().get::<ConnectInfo<TlsConnInfo>>().cloned();
    let client_ip = crate::transport::limiter_client_ip(req.extensions());

    // 1. mTLS identity established during the TLS handshake.
    if let Some(id) = tls_info.and_then(|ci| ci.0.identity) {
        state.log_auth(&id, "mTLS");
        let mut req = req;
        req.extensions_mut().insert(id);
        return next.run(req).await;
    }

    // 2. Pre-auth gate: cheap rejection before any Argon2/JWT work.
    if let Some(blocked) = pre_auth_gate(&state, client_ip) {
        #[cfg(feature = "metrics")]
        crate::metrics::record_rate_limit_deny(req.extensions(), "auth_pre");
        return blocked;
    }

    // 3. Bearer credential. A present header that is not a usable `Bearer <token>`
    //    (other scheme, non-visible-ASCII value, empty token) counts as invalid.
    let failure_class = if let Some(value) = req.headers().get(header::AUTHORIZATION) {
        match value.to_str().ok().and_then(extract_bearer) {
            Some(token) => match authenticate_bearer_identity(&state, token).await {
                Ok(id) => {
                    state.log_auth(&id, auth_method_label(id.method));
                    let mut req = req;
                    req.extensions_mut().insert(id);
                    return next.run(req).await;
                }
                Err(class) => class,
            },
            None => AuthFailureClass::InvalidCredential,
        }
    } else {
        AuthFailureClass::MissingCredential
    };

    tracing::warn!(failure_class = %failure_class.as_str(), "auth failed");

    // 4. Post-failure limiter. When exhausted, the failure is counted as
    //    `RateLimited` instead of its original class.
    if let (Some(limiter), Some(ip)) = (&state.rate_limiter, client_ip)
        && let Err(wait) = limiter.check_key_wait(&ip)
    {
        state.counters.record_failure(AuthFailureClass::RateLimited);
        #[cfg(feature = "metrics")]
        crate::metrics::record_rate_limit_deny(req.extensions(), "auth_post");
        tracing::warn!(%ip, "auth rate limited after repeated failures");
        return McpxError::RateLimitedFor {
            message: "too many failed authentication attempts".into(),
            retry_after: wait,
        }
        .into_response();
    }

    state.counters.record_failure(failure_class);
    unauthorized_response(&state, failure_class)
}
1466
1467#[cfg(test)]
1468mod tests {
1469 use super::*;
1470
1471 #[test]
1472 fn generate_and_verify_api_key() {
1473 let (token, hash) = generate_api_key().unwrap();
1474
1475 assert_eq!(token.len(), 43);
1477
1478 assert!(hash.starts_with("$argon2id$"));
1480
1481 let keys = vec![ApiKeyEntry {
1483 name: "test".into(),
1484 hash,
1485 role: "viewer".into(),
1486 expires_at: None,
1487 }];
1488 let id = verify_bearer_token(&token, &keys);
1489 assert!(id.is_some());
1490 let id = id.unwrap();
1491 assert_eq!(id.name, "test");
1492 assert_eq!(id.role, "viewer");
1493 assert_eq!(id.method, AuthMethod::BearerToken);
1494 }
1495
1496 #[test]
1497 fn wrong_token_rejected() {
1498 let (_token, hash) = generate_api_key().unwrap();
1499 let keys = vec![ApiKeyEntry {
1500 name: "test".into(),
1501 hash,
1502 role: "viewer".into(),
1503 expires_at: None,
1504 }];
1505 assert!(verify_bearer_token("wrong-token", &keys).is_none());
1506 }
1507
1508 #[test]
1509 fn expired_key_rejected() {
1510 let (token, hash) = generate_api_key().unwrap();
1511 let keys = vec![ApiKeyEntry {
1512 name: "test".into(),
1513 hash,
1514 role: "viewer".into(),
1515 expires_at: Some(RfcTimestamp::parse("2020-01-01T00:00:00Z").unwrap()),
1516 }];
1517 assert!(verify_bearer_token(&token, &keys).is_none());
1518 }
1519
1520 #[test]
1521 fn match_in_last_slot_still_authenticates() {
1522 let (token, hash) = generate_api_key().unwrap();
1523 let (_other_token, other_hash) = generate_api_key().unwrap();
1524 let keys = vec![
1525 ApiKeyEntry {
1526 name: "first".into(),
1527 hash: other_hash.clone(),
1528 role: "viewer".into(),
1529 expires_at: None,
1530 },
1531 ApiKeyEntry {
1532 name: "second".into(),
1533 hash: other_hash,
1534 role: "viewer".into(),
1535 expires_at: None,
1536 },
1537 ApiKeyEntry {
1538 name: "match".into(),
1539 hash,
1540 role: "ops".into(),
1541 expires_at: None,
1542 },
1543 ];
1544 let id = verify_bearer_token(&token, &keys).expect("last-slot match must authenticate");
1545 assert_eq!(id.name, "match");
1546 assert_eq!(id.role, "ops");
1547 }
1548
1549 #[test]
1550 fn expired_slot_before_valid_match_does_not_short_circuit() {
1551 let (token, hash) = generate_api_key().unwrap();
1552 let (_, other_hash) = generate_api_key().unwrap();
1553 let keys = vec![
1554 ApiKeyEntry {
1555 name: "expired".into(),
1556 hash: other_hash,
1557 role: "viewer".into(),
1558 expires_at: Some(RfcTimestamp::parse("2020-01-01T00:00:00Z").unwrap()),
1559 },
1560 ApiKeyEntry {
1561 name: "valid".into(),
1562 hash,
1563 role: "ops".into(),
1564 expires_at: None,
1565 },
1566 ];
1567 let id = verify_bearer_token(&token, &keys)
1568 .expect("valid slot following an expired slot must authenticate");
1569 assert_eq!(id.name, "valid");
1570 }
1571
1572 #[test]
1573 fn malformed_hash_slot_does_not_short_circuit() {
1574 let (token, hash) = generate_api_key().unwrap();
1575 let keys = vec![
1576 ApiKeyEntry {
1577 name: "broken".into(),
1578 hash: "this-is-not-a-phc-string".into(),
1579 role: "viewer".into(),
1580 expires_at: None,
1581 },
1582 ApiKeyEntry {
1583 name: "valid".into(),
1584 hash,
1585 role: "ops".into(),
1586 expires_at: None,
1587 },
1588 ];
1589 let id = verify_bearer_token(&token, &keys)
1590 .expect("valid slot following a malformed-hash slot must authenticate");
1591 assert_eq!(id.name, "valid");
1592 }
1593
1594 #[test]
1605 fn rfc_timestamp_parse_rejects_malformed() {
1606 for bad in [
1607 "not-a-date",
1608 "",
1609 "2025-13-01T00:00:00Z", "2025-01-32T00:00:00Z", "2025-01-01T00:00:00", "01/01/2025", "2025-01-01T25:00:00Z", ] {
1615 assert!(
1616 RfcTimestamp::parse(bad).is_err(),
1617 "RfcTimestamp::parse must reject {bad:?}"
1618 );
1619 }
1620 }
1621
1622 #[test]
1623 fn rfc_timestamp_parse_accepts_valid() {
1624 for good in [
1625 "2025-01-01T00:00:00Z",
1626 "2025-01-01T00:00:00+00:00",
1627 "2025-12-31T23:59:59-08:00",
1628 "2099-01-01T00:00:00.123456789Z",
1629 ] {
1630 assert!(
1631 RfcTimestamp::parse(good).is_ok(),
1632 "RfcTimestamp::parse must accept {good:?}"
1633 );
1634 }
1635 }
1636
1637 #[test]
1638 fn api_key_entry_deserialize_rejects_malformed_expires_at() {
1639 let toml = r#"
1644 name = "bad-key"
1645 hash = "$argon2id$v=19$m=19456,t=2,p=1$c2FsdA$h4sh"
1646 role = "viewer"
1647 expires_at = "not-a-date"
1648 "#;
1649 let result: Result<ApiKeyEntry, _> = toml::from_str(toml);
1650 assert!(
1651 result.is_err(),
1652 "deserialization must reject malformed expires_at"
1653 );
1654 }
1655
1656 #[test]
1657 fn api_key_entry_deserialize_accepts_valid_expires_at() {
1658 let toml = r#"
1659 name = "good-key"
1660 hash = "$argon2id$v=19$m=19456,t=2,p=1$c2FsdA$h4sh"
1661 role = "viewer"
1662 expires_at = "2099-01-01T00:00:00Z"
1663 "#;
1664 let entry: ApiKeyEntry = toml::from_str(toml).expect("valid RFC 3339 must deserialize");
1665 assert!(entry.expires_at.is_some());
1666 }
1667
1668 #[test]
1669 fn api_key_entry_deserialize_accepts_missing_expires_at() {
1670 let toml = r#"
1673 name = "eternal-key"
1674 hash = "$argon2id$v=19$m=19456,t=2,p=1$c2FsdA$h4sh"
1675 role = "viewer"
1676 "#;
1677 let entry: ApiKeyEntry = toml::from_str(toml).expect("missing expires_at must deserialize");
1678 assert!(entry.expires_at.is_none());
1679 }
1680
1681 #[test]
1682 fn try_with_expiry_rejects_malformed() {
1683 let entry = ApiKeyEntry::new("k", "hash", "viewer");
1684 assert!(entry.try_with_expiry("not-a-date").is_err());
1685 }
1686
1687 #[test]
1688 fn try_with_expiry_accepts_valid() {
1689 let entry = ApiKeyEntry::new("k", "hash", "viewer")
1690 .try_with_expiry("2099-01-01T00:00:00Z")
1691 .expect("valid RFC 3339 must be accepted");
1692 assert!(entry.expires_at.is_some());
1693 }
1694
1695 #[test]
1696 fn api_key_summary_serializes_expires_at_as_rfc3339() {
1697 let summary = ApiKeySummary {
1702 name: "k".into(),
1703 role: "viewer".into(),
1704 expires_at: Some(RfcTimestamp::parse("2030-01-01T00:00:00Z").unwrap()),
1705 };
1706 let json = serde_json::to_string(&summary).unwrap();
1707 assert!(
1708 json.contains(r#""expires_at":"2030-01-01T00:00:00+00:00""#),
1709 "wire format regressed: {json}"
1710 );
1711 }
1712
1713 #[test]
1714 fn future_expiry_accepted() {
1715 let (token, hash) = generate_api_key().unwrap();
1716 let keys = vec![ApiKeyEntry {
1717 name: "test".into(),
1718 hash,
1719 role: "viewer".into(),
1720 expires_at: Some(RfcTimestamp::parse("2099-01-01T00:00:00Z").unwrap()),
1721 }];
1722 assert!(verify_bearer_token(&token, &keys).is_some());
1723 }
1724
1725 #[test]
1726 fn multiple_keys_first_match_wins() {
1727 let (token, hash) = generate_api_key().unwrap();
1728 let keys = vec![
1729 ApiKeyEntry {
1730 name: "wrong".into(),
1731 hash: "$argon2id$v=19$m=19456,t=2,p=1$invalid$invalid".into(),
1732 role: "ops".into(),
1733 expires_at: None,
1734 },
1735 ApiKeyEntry {
1736 name: "correct".into(),
1737 hash,
1738 role: "deploy".into(),
1739 expires_at: None,
1740 },
1741 ];
1742 let id = verify_bearer_token(&token, &keys).unwrap();
1743 assert_eq!(id.name, "correct");
1744 assert_eq!(id.role, "deploy");
1745 }
1746
1747 #[test]
1748 fn rate_limiter_allows_within_quota() {
1749 let config = RateLimitConfig {
1750 max_attempts_per_minute: 5,
1751 pre_auth_max_per_minute: None,
1752 max_tracked_keys: default_max_tracked_keys(),
1753 idle_eviction: default_idle_eviction(),
1754 burst: None,
1755 pre_auth_burst: None,
1756 };
1757 let limiter = build_rate_limiter(&config);
1758 let ip: IpAddr = "10.0.0.1".parse().unwrap();
1759
1760 for _ in 0..5 {
1762 assert!(limiter.check_key(&ip).is_ok());
1763 }
1764 assert!(limiter.check_key(&ip).is_err());
1766 }
1767
1768 #[test]
1769 fn rate_limiter_separate_ips() {
1770 let config = RateLimitConfig {
1771 max_attempts_per_minute: 2,
1772 pre_auth_max_per_minute: None,
1773 max_tracked_keys: default_max_tracked_keys(),
1774 idle_eviction: default_idle_eviction(),
1775 burst: None,
1776 pre_auth_burst: None,
1777 };
1778 let limiter = build_rate_limiter(&config);
1779 let ip1: IpAddr = "10.0.0.1".parse().unwrap();
1780 let ip2: IpAddr = "10.0.0.2".parse().unwrap();
1781
1782 assert!(limiter.check_key(&ip1).is_ok());
1784 assert!(limiter.check_key(&ip1).is_ok());
1785 assert!(limiter.check_key(&ip1).is_err());
1786
1787 assert!(limiter.check_key(&ip2).is_ok());
1789 }
1790
1791 #[test]
1792 fn extract_mtls_identity_from_cn() {
1793 let mut params = rcgen::CertificateParams::new(vec!["test-client.local".into()]).unwrap();
1795 params.distinguished_name = rcgen::DistinguishedName::new();
1796 params
1797 .distinguished_name
1798 .push(rcgen::DnType::CommonName, "test-client");
1799 let cert = params
1800 .self_signed(&rcgen::KeyPair::generate().unwrap())
1801 .unwrap();
1802 let der = cert.der();
1803
1804 let id = extract_mtls_identity(der, "ops").unwrap();
1805 assert_eq!(id.name, "test-client");
1806 assert_eq!(id.role, "ops");
1807 assert_eq!(id.method, AuthMethod::MtlsCertificate);
1808 }
1809
1810 #[test]
1811 fn extract_mtls_identity_falls_back_to_san() {
1812 let mut params =
1814 rcgen::CertificateParams::new(vec!["san-only.example.com".into()]).unwrap();
1815 params.distinguished_name = rcgen::DistinguishedName::new();
1816 let cert = params
1818 .self_signed(&rcgen::KeyPair::generate().unwrap())
1819 .unwrap();
1820 let der = cert.der();
1821
1822 let id = extract_mtls_identity(der, "viewer").unwrap();
1823 assert_eq!(id.name, "san-only.example.com");
1824 assert_eq!(id.role, "viewer");
1825 }
1826
1827 #[test]
1828 fn extract_mtls_identity_invalid_der() {
1829 assert!(extract_mtls_identity(b"not-a-cert", "viewer").is_none());
1830 }
1831
1832 use axum::{
1835 body::Body,
1836 http::{Request, StatusCode},
1837 };
1838 use tower::ServiceExt as _;
1839
1840 fn auth_router(state: Arc<AuthState>) -> axum::Router {
1841 axum::Router::new()
1842 .route("/mcp", axum::routing::post(|| async { "ok" }))
1843 .layer(axum::middleware::from_fn(move |req, next| {
1844 let s = Arc::clone(&state);
1845 auth_middleware(s, req, next)
1846 }))
1847 }
1848
1849 fn test_auth_state(keys: Vec<ApiKeyEntry>) -> Arc<AuthState> {
1850 Arc::new(AuthState {
1851 api_keys: ArcSwap::new(Arc::new(keys)),
1852 rate_limiter: None,
1853 pre_auth_limiter: None,
1854 #[cfg(feature = "oauth")]
1855 jwks_cache: None,
1856 seen_identities: SeenIdentitySet::new(),
1857 counters: AuthCounters::default(),
1858 })
1859 }
1860
1861 #[tokio::test]
1862 async fn middleware_rejects_no_credentials() {
1863 let state = test_auth_state(vec![]);
1864 let app = auth_router(Arc::clone(&state));
1865 let req = Request::builder()
1866 .method(axum::http::Method::POST)
1867 .uri("/mcp")
1868 .body(Body::empty())
1869 .unwrap();
1870 let resp = app.oneshot(req).await.unwrap();
1871 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
1872 let challenge = resp
1873 .headers()
1874 .get(header::WWW_AUTHENTICATE)
1875 .unwrap()
1876 .to_str()
1877 .unwrap();
1878 assert!(challenge.contains("error=\"invalid_request\""));
1879
1880 let counters = state.counters_snapshot();
1881 assert_eq!(counters.failure_missing_credential, 1);
1882 }
1883
1884 #[tokio::test]
1885 async fn middleware_accepts_valid_bearer() {
1886 let (token, hash) = generate_api_key().unwrap();
1887 let keys = vec![ApiKeyEntry {
1888 name: "test-key".into(),
1889 hash,
1890 role: "ops".into(),
1891 expires_at: None,
1892 }];
1893 let state = test_auth_state(keys);
1894 let app = auth_router(Arc::clone(&state));
1895 let req = Request::builder()
1896 .method(axum::http::Method::POST)
1897 .uri("/mcp")
1898 .header("authorization", format!("Bearer {token}"))
1899 .body(Body::empty())
1900 .unwrap();
1901 let resp = app.oneshot(req).await.unwrap();
1902 assert_eq!(resp.status(), StatusCode::OK);
1903
1904 let counters = state.counters_snapshot();
1905 assert_eq!(counters.success_bearer, 1);
1906 }
1907
1908 #[tokio::test]
1909 async fn middleware_rejects_wrong_bearer() {
1910 let (_token, hash) = generate_api_key().unwrap();
1911 let keys = vec![ApiKeyEntry {
1912 name: "test-key".into(),
1913 hash,
1914 role: "ops".into(),
1915 expires_at: None,
1916 }];
1917 let state = test_auth_state(keys);
1918 let app = auth_router(Arc::clone(&state));
1919 let req = Request::builder()
1920 .method(axum::http::Method::POST)
1921 .uri("/mcp")
1922 .header("authorization", "Bearer wrong-token-here")
1923 .body(Body::empty())
1924 .unwrap();
1925 let resp = app.oneshot(req).await.unwrap();
1926 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
1927 let challenge = resp
1928 .headers()
1929 .get(header::WWW_AUTHENTICATE)
1930 .unwrap()
1931 .to_str()
1932 .unwrap();
1933 assert!(challenge.contains("error=\"invalid_token\""));
1934
1935 let counters = state.counters_snapshot();
1936 assert_eq!(counters.failure_invalid_credential, 1);
1937 }
1938
1939 #[tokio::test]
1940 async fn middleware_rate_limits() {
1941 let state = Arc::new(AuthState {
1942 api_keys: ArcSwap::new(Arc::new(vec![])),
1943 rate_limiter: Some(build_rate_limiter(&RateLimitConfig {
1944 max_attempts_per_minute: 1,
1945 pre_auth_max_per_minute: None,
1946 max_tracked_keys: default_max_tracked_keys(),
1947 idle_eviction: default_idle_eviction(),
1948 burst: None,
1949 pre_auth_burst: None,
1950 })),
1951 pre_auth_limiter: None,
1952 #[cfg(feature = "oauth")]
1953 jwks_cache: None,
1954 seen_identities: SeenIdentitySet::new(),
1955 counters: AuthCounters::default(),
1956 });
1957 let app = auth_router(state);
1958
1959 let req = Request::builder()
1961 .method(axum::http::Method::POST)
1962 .uri("/mcp")
1963 .body(Body::empty())
1964 .unwrap();
1965 let resp = app.clone().oneshot(req).await.unwrap();
1966 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
1967
1968 }
1973
1974 #[test]
1980 fn rate_limit_semantics_failed_only() {
1981 let config = RateLimitConfig {
1982 max_attempts_per_minute: 3,
1983 pre_auth_max_per_minute: None,
1984 max_tracked_keys: default_max_tracked_keys(),
1985 idle_eviction: default_idle_eviction(),
1986 burst: None,
1987 pre_auth_burst: None,
1988 };
1989 let limiter = build_rate_limiter(&config);
1990 let ip: IpAddr = "192.168.1.100".parse().unwrap();
1991
1992 assert!(
1994 limiter.check_key(&ip).is_ok(),
1995 "failure 1 should be allowed"
1996 );
1997 assert!(
1998 limiter.check_key(&ip).is_ok(),
1999 "failure 2 should be allowed"
2000 );
2001 assert!(
2002 limiter.check_key(&ip).is_ok(),
2003 "failure 3 should be allowed"
2004 );
2005 assert!(
2006 limiter.check_key(&ip).is_err(),
2007 "failure 4 should be blocked"
2008 );
2009
2010 }
2019
2020 #[test]
2025 fn pre_auth_default_multiplier_is_10x() {
2026 let config = RateLimitConfig {
2027 max_attempts_per_minute: 5,
2028 pre_auth_max_per_minute: None,
2029 max_tracked_keys: default_max_tracked_keys(),
2030 idle_eviction: default_idle_eviction(),
2031 burst: None,
2032 pre_auth_burst: None,
2033 };
2034 let limiter = build_pre_auth_limiter(&config);
2035 let ip: IpAddr = "10.0.0.1".parse().unwrap();
2036
2037 for i in 0..50 {
2039 assert!(
2040 limiter.check_key(&ip).is_ok(),
2041 "pre-auth attempt {i} (of expected 50) should be allowed under default 10x multiplier"
2042 );
2043 }
2044 assert!(
2046 limiter.check_key(&ip).is_err(),
2047 "pre-auth attempt 51 should be blocked (quota is 50, not unbounded)"
2048 );
2049 }
2050
2051 #[test]
2054 fn pre_auth_explicit_override_wins() {
2055 let config = RateLimitConfig {
2056 max_attempts_per_minute: 100, pre_auth_max_per_minute: Some(2), max_tracked_keys: default_max_tracked_keys(),
2059 idle_eviction: default_idle_eviction(),
2060 burst: None,
2061 pre_auth_burst: None,
2062 };
2063 let limiter = build_pre_auth_limiter(&config);
2064 let ip: IpAddr = "10.0.0.2".parse().unwrap();
2065
2066 assert!(limiter.check_key(&ip).is_ok(), "attempt 1 allowed");
2067 assert!(limiter.check_key(&ip).is_ok(), "attempt 2 allowed");
2068 assert!(
2069 limiter.check_key(&ip).is_err(),
2070 "attempt 3 must be blocked (explicit override of 2 wins over 10x default of 1000)"
2071 );
2072 }
2073
2074 #[test]
2076 fn pre_auth_gate_deny_sets_retry_after() {
2077 let config = RateLimitConfig::new(100).with_pre_auth_max_per_minute(1);
2078 let state = AuthState {
2079 api_keys: ArcSwap::new(Arc::new(vec![])),
2080 rate_limiter: None,
2081 pre_auth_limiter: Some(build_pre_auth_limiter(&config)),
2082 #[cfg(feature = "oauth")]
2083 jwks_cache: None,
2084 seen_identities: SeenIdentitySet::new(),
2085 counters: AuthCounters::default(),
2086 };
2087 let ip: IpAddr = "10.7.7.7".parse().unwrap();
2088 assert!(
2089 pre_auth_gate(&state, Some(ip)).is_none(),
2090 "first request within quota"
2091 );
2092 let resp = pre_auth_gate(&state, Some(ip)).expect("second request must be gated");
2093 assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
2094 let retry_after = resp
2095 .headers()
2096 .get(header::RETRY_AFTER)
2097 .expect("Retry-After present")
2098 .to_str()
2099 .unwrap()
2100 .parse::<u64>()
2101 .unwrap();
2102 assert!(retry_after >= 1, "delta-seconds must be >= 1");
2103 }
2104
2105 #[test]
2107 fn post_failure_limiter_burst_allows_initial_spike() {
2108 let config = RateLimitConfig::new(1).with_burst(3);
2109 let limiter = build_rate_limiter(&config);
2110 let ip: IpAddr = "10.6.6.6".parse().unwrap();
2111 for i in 0..3 {
2112 assert!(limiter.check_key(&ip).is_ok(), "burst attempt {i}");
2113 }
2114 assert!(
2115 limiter.check_key(&ip).is_err(),
2116 "attempt 4 must exceed the burst bucket"
2117 );
2118 }
2119
2120 #[tokio::test]
2126 async fn pre_auth_gate_blocks_before_argon2_verification() {
2127 let (_token, hash) = generate_api_key().unwrap();
2128 let keys = vec![ApiKeyEntry {
2129 name: "test-key".into(),
2130 hash,
2131 role: "ops".into(),
2132 expires_at: None,
2133 }];
2134 let config = RateLimitConfig {
2135 max_attempts_per_minute: 100,
2136 pre_auth_max_per_minute: Some(1),
2137 max_tracked_keys: default_max_tracked_keys(),
2138 idle_eviction: default_idle_eviction(),
2139 burst: None,
2140 pre_auth_burst: None,
2141 };
2142 let state = Arc::new(AuthState {
2143 api_keys: ArcSwap::new(Arc::new(keys)),
2144 rate_limiter: None,
2145 pre_auth_limiter: Some(build_pre_auth_limiter(&config)),
2146 #[cfg(feature = "oauth")]
2147 jwks_cache: None,
2148 seen_identities: SeenIdentitySet::new(),
2149 counters: AuthCounters::default(),
2150 });
2151 let app = auth_router(Arc::clone(&state));
2152 let peer: SocketAddr = "10.0.0.10:54321".parse().unwrap();
2153
2154 let mut req1 = Request::builder()
2157 .method(axum::http::Method::POST)
2158 .uri("/mcp")
2159 .header("authorization", "Bearer obviously-not-a-real-token")
2160 .body(Body::empty())
2161 .unwrap();
2162 req1.extensions_mut().insert(ConnectInfo(peer));
2163 let resp1 = app.clone().oneshot(req1).await.unwrap();
2164 assert_eq!(
2165 resp1.status(),
2166 StatusCode::UNAUTHORIZED,
2167 "first attempt: gate has quota, falls through to bearer auth which fails with 401"
2168 );
2169
2170 let mut req2 = Request::builder()
2173 .method(axum::http::Method::POST)
2174 .uri("/mcp")
2175 .header("authorization", "Bearer also-not-a-real-token")
2176 .body(Body::empty())
2177 .unwrap();
2178 req2.extensions_mut().insert(ConnectInfo(peer));
2179 let resp2 = app.oneshot(req2).await.unwrap();
2180 assert_eq!(
2181 resp2.status(),
2182 StatusCode::TOO_MANY_REQUESTS,
2183 "second attempt from same IP: pre-auth gate must reject with 429"
2184 );
2185
2186 let counters = state.counters_snapshot();
2187 assert_eq!(
2188 counters.failure_pre_auth_gate, 1,
2189 "exactly one request must have been rejected by the pre-auth gate"
2190 );
2191 assert_eq!(
2195 counters.failure_invalid_credential, 1,
2196 "bearer verification must run exactly once (only the un-gated first request)"
2197 );
2198 }
2199
2200 #[tokio::test]
2207 async fn pre_auth_gate_does_not_throttle_mtls() {
2208 let config = RateLimitConfig {
2209 max_attempts_per_minute: 100,
2210 pre_auth_max_per_minute: Some(1), max_tracked_keys: default_max_tracked_keys(),
2212 idle_eviction: default_idle_eviction(),
2213 burst: None,
2214 pre_auth_burst: None,
2215 };
2216 let state = Arc::new(AuthState {
2217 api_keys: ArcSwap::new(Arc::new(vec![])),
2218 rate_limiter: None,
2219 pre_auth_limiter: Some(build_pre_auth_limiter(&config)),
2220 #[cfg(feature = "oauth")]
2221 jwks_cache: None,
2222 seen_identities: SeenIdentitySet::new(),
2223 counters: AuthCounters::default(),
2224 });
2225 let app = auth_router(Arc::clone(&state));
2226 let peer: SocketAddr = "10.0.0.20:54321".parse().unwrap();
2227 let identity = AuthIdentity {
2228 name: "cn=test-client".into(),
2229 role: "viewer".into(),
2230 method: AuthMethod::MtlsCertificate,
2231 raw_token: None,
2232 sub: None,
2233 };
2234 let tls_info = TlsConnInfo::new(peer, Some(identity));
2235
2236 for i in 0..3 {
2237 let mut req = Request::builder()
2238 .method(axum::http::Method::POST)
2239 .uri("/mcp")
2240 .body(Body::empty())
2241 .unwrap();
2242 req.extensions_mut().insert(ConnectInfo(tls_info.clone()));
2243 let resp = app.clone().oneshot(req).await.unwrap();
2244 assert_eq!(
2245 resp.status(),
2246 StatusCode::OK,
2247 "mTLS request {i} must succeed: pre-auth gate must not apply to mTLS callers"
2248 );
2249 }
2250
2251 let counters = state.counters_snapshot();
2252 assert_eq!(
2253 counters.failure_pre_auth_gate, 0,
2254 "pre-auth gate counter must remain at zero: mTLS bypasses the gate"
2255 );
2256 assert_eq!(
2257 counters.success_mtls, 3,
2258 "all three mTLS requests must have been counted as successful"
2259 );
2260 }
2261
2262 #[cfg(feature = "metrics")]
2265 #[tokio::test]
2266 async fn pre_auth_gate_deny_increments_counter() {
2267 let config = RateLimitConfig {
2268 max_attempts_per_minute: 100,
2269 pre_auth_max_per_minute: Some(1),
2270 max_tracked_keys: default_max_tracked_keys(),
2271 idle_eviction: default_idle_eviction(),
2272 burst: None,
2273 pre_auth_burst: None,
2274 };
2275 let state = Arc::new(AuthState {
2276 api_keys: ArcSwap::new(Arc::new(vec![])),
2277 rate_limiter: None,
2278 pre_auth_limiter: Some(build_pre_auth_limiter(&config)),
2279 #[cfg(feature = "oauth")]
2280 jwks_cache: None,
2281 seen_identities: SeenIdentitySet::new(),
2282 counters: AuthCounters::default(),
2283 });
2284 let app = auth_router(Arc::clone(&state));
2285 let metrics = Arc::new(crate::metrics::McpMetrics::new().expect("metrics registry"));
2286 let peer: SocketAddr = "10.0.0.30:54321".parse().expect("addr parses");
2287 let mk = || {
2288 let mut req = Request::builder()
2289 .method(axum::http::Method::POST)
2290 .uri("/mcp")
2291 .header("authorization", "Bearer not-a-real-token")
2292 .body(Body::empty())
2293 .expect("request builds");
2294 req.extensions_mut().insert(ConnectInfo(peer));
2295 req.extensions_mut().insert(Arc::clone(&metrics));
2296 req
2297 };
2298 let counter = |label: &str| metrics.rate_limited_total.with_label_values(&[label]).get();
2299
2300 let first = app.clone().oneshot(mk()).await.expect("first request");
2301 assert_eq!(first.status(), StatusCode::UNAUTHORIZED);
2302 assert_eq!(counter("auth_pre"), 0, "un-gated request must not count");
2303
2304 let gated = app.oneshot(mk()).await.expect("second request");
2305 assert_eq!(gated.status(), StatusCode::TOO_MANY_REQUESTS);
2306 assert_eq!(counter("auth_pre"), 1, "gated request must count once");
2307 assert_eq!(counter("auth_post"), 0, "post limiter never fired");
2308 }
2309
2310 #[cfg(feature = "metrics")]
2313 #[tokio::test]
2314 async fn post_failure_limiter_deny_increments_counter() {
2315 let config = RateLimitConfig {
2316 max_attempts_per_minute: 1, pre_auth_max_per_minute: None,
2318 max_tracked_keys: default_max_tracked_keys(),
2319 idle_eviction: default_idle_eviction(),
2320 burst: None,
2321 pre_auth_burst: None,
2322 };
2323 let state = Arc::new(AuthState {
2324 api_keys: ArcSwap::new(Arc::new(vec![])),
2325 rate_limiter: Some(build_rate_limiter(&config)),
2326 pre_auth_limiter: None,
2327 #[cfg(feature = "oauth")]
2328 jwks_cache: None,
2329 seen_identities: SeenIdentitySet::new(),
2330 counters: AuthCounters::default(),
2331 });
2332 let app = auth_router(Arc::clone(&state));
2333 let metrics = Arc::new(crate::metrics::McpMetrics::new().expect("metrics registry"));
2334 let peer: SocketAddr = "10.0.0.31:54321".parse().expect("addr parses");
2335 let mk = || {
2336 let mut req = Request::builder()
2337 .method(axum::http::Method::POST)
2338 .uri("/mcp")
2339 .header("authorization", "Bearer not-a-real-token")
2340 .body(Body::empty())
2341 .expect("request builds");
2342 req.extensions_mut().insert(ConnectInfo(peer));
2343 req.extensions_mut().insert(Arc::clone(&metrics));
2344 req
2345 };
2346 let counter = |label: &str| metrics.rate_limited_total.with_label_values(&[label]).get();
2347
2348 let first = app.clone().oneshot(mk()).await.expect("first request");
2350 assert_eq!(first.status(), StatusCode::UNAUTHORIZED);
2351 assert_eq!(counter("auth_post"), 0);
2352
2353 let limited = app.oneshot(mk()).await.expect("second request");
2355 assert_eq!(limited.status(), StatusCode::TOO_MANY_REQUESTS);
2356 assert_eq!(counter("auth_post"), 1, "deny must count once");
2357 assert_eq!(counter("auth_pre"), 0, "pre-auth gate disabled here");
2358 }
2359
2360 #[test]
2365 fn extract_bearer_accepts_canonical_case() {
2366 assert_eq!(extract_bearer("Bearer abc123"), Some("abc123"));
2367 }
2368
2369 #[test]
2370 fn extract_bearer_is_case_insensitive_per_rfc7235() {
2371 for header in &[
2375 "bearer abc123",
2376 "BEARER abc123",
2377 "BeArEr abc123",
2378 "bEaReR abc123",
2379 ] {
2380 assert_eq!(
2381 extract_bearer(header),
2382 Some("abc123"),
2383 "header {header:?} must parse as a Bearer token (RFC 7235 §2.1)"
2384 );
2385 }
2386 }
2387
2388 #[test]
2389 fn extract_bearer_rejects_other_schemes() {
2390 assert_eq!(extract_bearer("Basic dXNlcjpwYXNz"), None);
2391 assert_eq!(extract_bearer("Digest username=\"x\""), None);
2392 assert_eq!(extract_bearer("Token abc123"), None);
2393 }
2394
2395 #[test]
2396 fn extract_bearer_rejects_malformed() {
2397 assert_eq!(extract_bearer(""), None);
2399 assert_eq!(extract_bearer("Bearer"), None);
2400 assert_eq!(extract_bearer("Bearer "), None);
2401 assert_eq!(extract_bearer("Bearer "), None);
2402 }
2403
2404 #[test]
2405 fn extract_bearer_tolerates_extra_separator_whitespace() {
2406 assert_eq!(extract_bearer("Bearer abc123"), Some("abc123"));
2408 assert_eq!(extract_bearer("Bearer abc123"), Some("abc123"));
2409 }
2410
2411 #[test]
2417 fn auth_identity_debug_redacts_raw_token() {
2418 let id = AuthIdentity {
2419 name: "alice".into(),
2420 role: "admin".into(),
2421 method: AuthMethod::OAuthJwt,
2422 raw_token: Some(SecretString::from("super-secret-jwt-payload-xyz")),
2423 sub: Some("keycloak-uuid-2f3c8b".into()),
2424 };
2425 let dbg = format!("{id:?}");
2426
2427 assert!(dbg.contains("alice"), "name should be visible: {dbg}");
2429 assert!(dbg.contains("admin"), "role should be visible: {dbg}");
2430 assert!(dbg.contains("OAuthJwt"), "method should be visible: {dbg}");
2431
2432 assert!(
2434 !dbg.contains("super-secret-jwt-payload-xyz"),
2435 "raw_token must be redacted in Debug output: {dbg}"
2436 );
2437 assert!(
2438 !dbg.contains("keycloak-uuid-2f3c8b"),
2439 "sub must be redacted in Debug output: {dbg}"
2440 );
2441 assert!(
2442 dbg.contains("<redacted>"),
2443 "redaction marker missing: {dbg}"
2444 );
2445 }
2446
2447 #[test]
2448 fn auth_identity_debug_marks_absent_secrets() {
2449 let id = AuthIdentity {
2452 name: "viewer-key".into(),
2453 role: "viewer".into(),
2454 method: AuthMethod::BearerToken,
2455 raw_token: None,
2456 sub: None,
2457 };
2458 let dbg = format!("{id:?}");
2459 assert!(
2460 dbg.contains("<none>"),
2461 "absent secrets should be marked: {dbg}"
2462 );
2463 assert!(
2464 !dbg.contains("<redacted>"),
2465 "no <redacted> marker when secrets are absent: {dbg}"
2466 );
2467 }
2468
2469 #[test]
2470 fn api_key_entry_debug_redacts_hash() {
2471 let entry = ApiKeyEntry {
2472 name: "viewer-key".into(),
2473 hash: "$argon2id$v=19$m=19456,t=2,p=1$c2FsdHNhbHQ$h4sh3dPa55w0rd".into(),
2475 role: "viewer".into(),
2476 expires_at: Some(RfcTimestamp::parse("2030-01-01T00:00:00Z").unwrap()),
2477 };
2478 let dbg = format!("{entry:?}");
2479
2480 assert!(dbg.contains("viewer-key"));
2482 assert!(dbg.contains("viewer"));
2483 assert!(dbg.contains("2030-01-01T00:00:00+00:00"));
2484
2485 assert!(
2487 !dbg.contains("$argon2id$"),
2488 "argon2 hash leaked into Debug output: {dbg}"
2489 );
2490 assert!(
2491 !dbg.contains("h4sh3dPa55w0rd"),
2492 "hash digest leaked into Debug output: {dbg}"
2493 );
2494 assert!(
2495 dbg.contains("<redacted>"),
2496 "redaction marker missing: {dbg}"
2497 );
2498 }
2499
2500 #[test]
2511 fn auth_failure_class_as_str_exact_strings() {
2512 assert_eq!(
2513 AuthFailureClass::MissingCredential.as_str(),
2514 "missing_credential"
2515 );
2516 assert_eq!(
2517 AuthFailureClass::InvalidCredential.as_str(),
2518 "invalid_credential"
2519 );
2520 assert_eq!(
2521 AuthFailureClass::ExpiredCredential.as_str(),
2522 "expired_credential"
2523 );
2524 assert_eq!(AuthFailureClass::RateLimited.as_str(), "rate_limited");
2525 assert_eq!(AuthFailureClass::PreAuthGate.as_str(), "pre_auth_gate");
2526 }
2527
2528 #[test]
2529 fn auth_failure_class_response_body_exact_strings() {
2530 assert_eq!(
2531 AuthFailureClass::MissingCredential.response_body(),
2532 "unauthorized: missing credential"
2533 );
2534 assert_eq!(
2535 AuthFailureClass::InvalidCredential.response_body(),
2536 "unauthorized: invalid credential"
2537 );
2538 assert_eq!(
2539 AuthFailureClass::ExpiredCredential.response_body(),
2540 "unauthorized: expired credential"
2541 );
2542 assert_eq!(
2543 AuthFailureClass::RateLimited.response_body(),
2544 "rate limited"
2545 );
2546 assert_eq!(
2547 AuthFailureClass::PreAuthGate.response_body(),
2548 "rate limited (pre-auth)"
2549 );
2550 }
2551
2552 #[test]
2553 fn auth_failure_class_bearer_error_exact_strings() {
2554 assert_eq!(
2555 AuthFailureClass::MissingCredential.bearer_error(),
2556 (
2557 "invalid_request",
2558 "missing bearer token or mTLS client certificate"
2559 )
2560 );
2561 assert_eq!(
2562 AuthFailureClass::InvalidCredential.bearer_error(),
2563 ("invalid_token", "token is invalid")
2564 );
2565 assert_eq!(
2566 AuthFailureClass::ExpiredCredential.bearer_error(),
2567 ("invalid_token", "token is expired")
2568 );
2569 assert_eq!(
2570 AuthFailureClass::RateLimited.bearer_error(),
2571 ("invalid_request", "too many failed authentication attempts")
2572 );
2573 assert_eq!(
2574 AuthFailureClass::PreAuthGate.bearer_error(),
2575 (
2576 "invalid_request",
2577 "too many unauthenticated requests from this source"
2578 )
2579 );
2580 }
2581
2582 #[test]
2591 fn auth_config_summary_bearer_true_when_keys_present() {
2592 let (_token, hash) = generate_api_key().unwrap();
2593 let cfg = AuthConfig::with_keys(vec![ApiKeyEntry::new("k", hash, "viewer")]);
2594 let s = cfg.summary();
2595 assert!(s.enabled, "summary.enabled must reflect AuthConfig.enabled");
2596 assert!(
2597 s.bearer,
2598 "summary.bearer must be true when api_keys is non-empty (kills `!` deletion at L615)"
2599 );
2600 assert!(!s.mtls, "summary.mtls must be false when mtls is None");
2601 assert!(!s.oauth, "summary.oauth must be false when oauth is None");
2602 assert_eq!(s.api_keys.len(), 1);
2603 assert_eq!(s.api_keys[0].name, "k");
2604 assert_eq!(s.api_keys[0].role, "viewer");
2605 }
2606
2607 #[test]
2608 fn auth_config_summary_bearer_false_when_no_keys() {
2609 let cfg = AuthConfig::with_keys(vec![]);
2610 let s = cfg.summary();
2611 assert!(
2612 !s.bearer,
2613 "summary.bearer must be false when api_keys is empty (kills `!` deletion at L615)"
2614 );
2615 assert!(s.api_keys.is_empty());
2616 }
2617
2618 #[test]
2619 fn seen_identity_set_first_then_repeat() {
2620 let set = SeenIdentitySet::new();
2621 assert!(set.insert_is_first("alice"), "first sighting is first");
2622 assert!(
2623 !set.insert_is_first("alice"),
2624 "second sighting is not first"
2625 );
2626 assert!(set.insert_is_first("bob"));
2627 assert_eq!(set.len(), 2);
2628 }
2629
2630 #[test]
2631 fn seen_identity_set_evicts_oldest_at_cap() {
2632 let set = SeenIdentitySet::with_cap(2);
2633 assert!(set.insert_is_first("a"));
2634 assert!(set.insert_is_first("b"));
2635 assert!(set.insert_is_first("c"));
2637 assert_eq!(set.len(), 2);
2638 assert!(set.insert_is_first("a"));
2642 assert_eq!(set.len(), 2);
2643 assert!(set.insert_is_first("b"));
2645 for i in 0..32 {
2647 set.insert_is_first(&format!("churn-{i}"));
2648 assert!(set.len() <= 2, "cap invariant must hold");
2649 }
2650 }
2651
2652 #[test]
2653 fn seen_identity_set_cap_zero_is_raised_to_one() {
2654 let set = SeenIdentitySet::with_cap(0);
2655 assert!(set.insert_is_first("only"));
2656 assert_eq!(set.len(), 1);
2657 assert!(set.insert_is_first("next"));
2659 assert_eq!(set.len(), 1);
2660 }
2661
2662 #[test]
2663 fn seen_identity_set_fifo_does_not_refresh_on_repeat_hit() {
2664 let set = SeenIdentitySet::with_cap(2);
2667 assert!(set.insert_is_first("a")); assert!(set.insert_is_first("b")); assert!(!set.insert_is_first("a"));
2673 assert!(set.insert_is_first("c"));
2676 assert!(set.insert_is_first("a"));
2678 let set = SeenIdentitySet::with_cap(2);
2684 assert!(set.insert_is_first("x")); assert!(set.insert_is_first("y")); assert!(!set.insert_is_first("x")); assert!(set.insert_is_first("z")); assert!(
2689 !set.insert_is_first("y"),
2690 "y must still be present (FIFO did not evict it)"
2691 );
2692 assert!(
2693 set.insert_is_first("x"),
2694 "x must have been evicted by FIFO (would NOT have been evicted under LRU)"
2695 );
2696 }
2697}