use std::{
collections::HashSet,
net::{IpAddr, SocketAddr},
num::NonZeroU32,
path::PathBuf,
sync::{
Arc, Mutex,
atomic::{AtomicU64, Ordering},
},
time::Duration,
};
use arc_swap::ArcSwap;
use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier, password_hash::SaltString};
use axum::{
body::Body,
extract::ConnectInfo,
http::{Request, header},
middleware::Next,
response::{IntoResponse, Response},
};
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use secrecy::SecretString;
use serde::Deserialize;
use x509_parser::prelude::*;
use crate::{bounded_limiter::BoundedKeyedLimiter, error::McpxError};
/// The authenticated principal attached to a request's extensions.
///
/// `Debug` is implemented by hand so that `raw_token` and `sub` are never
/// printed; only whether each one is present is shown.
#[derive(Clone)]
#[non_exhaustive]
pub struct AuthIdentity {
    /// Identity name, e.g. the API key's `name` or the client certificate's
    /// CN / SAN DNS name.
    pub name: String,
    /// Role assigned to this identity.
    pub role: String,
    /// Mechanism that produced this identity.
    pub method: AuthMethod,
    /// The presented bearer token. Set on the OAuth JWT path; `None` for
    /// API-key and mTLS identities.
    pub raw_token: Option<SecretString>,
    /// Subject identifier, if any. `None` on the API-key and mTLS paths;
    /// presumably populated by OAuth validation (not visible here — confirm).
    pub sub: Option<String>,
}
impl std::fmt::Debug for AuthIdentity {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Replaces a secret with a marker describing only its presence.
        fn mask(present: bool) -> &'static str {
            if present { "<redacted>" } else { "<none>" }
        }
        let mut out = f.debug_struct("AuthIdentity");
        out.field("name", &self.name)
            .field("role", &self.role)
            .field("method", &self.method)
            .field("raw_token", &mask(self.raw_token.is_some()))
            .field("sub", &mask(self.sub.is_some()));
        out.finish()
    }
}
/// How a request's `AuthIdentity` was established.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum AuthMethod {
    /// API key sent as an `Authorization: Bearer` token and matched against
    /// a configured argon2 hash.
    BearerToken,
    /// Identity taken from the client certificate of an mTLS connection.
    MtlsCertificate,
    /// OAuth access token validated as a JWT (validation lives in the
    /// `oauth` module, not shown here).
    OAuthJwt,
}
/// Why an authentication attempt was rejected.
///
/// Each class maps to a metrics label, an RFC 6750 `WWW-Authenticate`
/// `(error, error_description)` pair, and a plain-text response body.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AuthFailureClass {
    /// No `Authorization` header was sent.
    MissingCredential,
    /// A credential was sent but was malformed or did not verify.
    InvalidCredential,
    /// An OAuth JWT was recognised but has expired.
    #[cfg_attr(not(feature = "oauth"), allow(dead_code))]
    ExpiredCredential,
    /// The per-IP failed-attempt limiter is exhausted.
    RateLimited,
    /// The per-IP pre-authentication gate rejected the request.
    PreAuthGate,
}
impl AuthFailureClass {
    /// Single source of truth for every textual form of a failure class:
    /// `(metrics label, (bearer error, bearer error_description), body)`.
    const fn descriptor(self) -> (&'static str, (&'static str, &'static str), &'static str) {
        match self {
            Self::MissingCredential => (
                "missing_credential",
                (
                    "invalid_request",
                    "missing bearer token or mTLS client certificate",
                ),
                "unauthorized: missing credential",
            ),
            Self::InvalidCredential => (
                "invalid_credential",
                ("invalid_token", "token is invalid"),
                "unauthorized: invalid credential",
            ),
            Self::ExpiredCredential => (
                "expired_credential",
                ("invalid_token", "token is expired"),
                "unauthorized: expired credential",
            ),
            Self::RateLimited => (
                "rate_limited",
                ("invalid_request", "too many failed authentication attempts"),
                "rate limited",
            ),
            Self::PreAuthGate => (
                "pre_auth_gate",
                (
                    "invalid_request",
                    "too many unauthenticated requests from this source",
                ),
                "rate limited (pre-auth)",
            ),
        }
    }
    /// Stable snake_case label used in logs.
    fn as_str(self) -> &'static str {
        self.descriptor().0
    }
    /// `(error, error_description)` for the `WWW-Authenticate` challenge.
    fn bearer_error(self) -> (&'static str, &'static str) {
        self.descriptor().1
    }
    /// Plain-text body for the rejection response.
    fn response_body(self) -> &'static str {
        self.descriptor().2
    }
}
/// Point-in-time copy of the authentication success/failure counters.
///
/// Counts accumulate from when the owning counters were created; nothing in
/// this module resets them. Each field is loaded independently with
/// `Ordering::Relaxed`, so fields are not guaranteed to be mutually
/// consistent with one another.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
#[non_exhaustive]
pub struct AuthCountersSnapshot {
    /// Requests authenticated via an mTLS client certificate.
    pub success_mtls: u64,
    /// Requests authenticated via an API-key bearer token.
    pub success_bearer: u64,
    /// Requests authenticated via an OAuth JWT.
    pub success_oauth_jwt: u64,
    /// Rejections because no `Authorization` header was sent.
    pub failure_missing_credential: u64,
    /// Rejections because the credential was malformed or did not verify.
    pub failure_invalid_credential: u64,
    /// Rejections because an OAuth JWT had expired.
    pub failure_expired_credential: u64,
    /// Rejections by the per-IP failed-attempt limiter.
    pub failure_rate_limited: u64,
    /// Rejections by the per-IP pre-authentication gate.
    pub failure_pre_auth_gate: u64,
}
/// Lock-free success/failure counters, one atomic per outcome.
#[derive(Debug, Default)]
pub(crate) struct AuthCounters {
    success_mtls: AtomicU64,
    success_bearer: AtomicU64,
    success_oauth_jwt: AtomicU64,
    failure_missing_credential: AtomicU64,
    failure_invalid_credential: AtomicU64,
    failure_expired_credential: AtomicU64,
    failure_rate_limited: AtomicU64,
    failure_pre_auth_gate: AtomicU64,
}
impl AuthCounters {
    /// Picks the counter that tracks successes for `method`.
    fn success_slot(&self, method: AuthMethod) -> &AtomicU64 {
        match method {
            AuthMethod::MtlsCertificate => &self.success_mtls,
            AuthMethod::BearerToken => &self.success_bearer,
            AuthMethod::OAuthJwt => &self.success_oauth_jwt,
        }
    }
    /// Picks the counter that tracks failures of kind `class`.
    fn failure_slot(&self, class: AuthFailureClass) -> &AtomicU64 {
        match class {
            AuthFailureClass::MissingCredential => &self.failure_missing_credential,
            AuthFailureClass::InvalidCredential => &self.failure_invalid_credential,
            AuthFailureClass::ExpiredCredential => &self.failure_expired_credential,
            AuthFailureClass::RateLimited => &self.failure_rate_limited,
            AuthFailureClass::PreAuthGate => &self.failure_pre_auth_gate,
        }
    }
    /// Counts one successful authentication via `method`.
    fn record_success(&self, method: AuthMethod) {
        self.success_slot(method).fetch_add(1, Ordering::Relaxed);
    }
    /// Counts one rejected request of kind `class`.
    fn record_failure(&self, class: AuthFailureClass) {
        self.failure_slot(class).fetch_add(1, Ordering::Relaxed);
    }
    /// Reads every counter (relaxed; no cross-field consistency).
    fn snapshot(&self) -> AuthCountersSnapshot {
        let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        AuthCountersSnapshot {
            success_mtls: read(&self.success_mtls),
            success_bearer: read(&self.success_bearer),
            success_oauth_jwt: read(&self.success_oauth_jwt),
            failure_missing_credential: read(&self.failure_missing_credential),
            failure_invalid_credential: read(&self.failure_invalid_credential),
            failure_expired_credential: read(&self.failure_expired_credential),
            failure_rate_limited: read(&self.failure_rate_limited),
            failure_pre_auth_gate: read(&self.failure_pre_auth_gate),
        }
    }
}
/// One configured API key: a name, an argon2 PHC hash of the secret, and a role.
///
/// `Debug` never prints the hash.
#[derive(Clone, Deserialize)]
#[non_exhaustive]
pub struct ApiKeyEntry {
    /// Human-readable key name; becomes the identity name on success.
    pub name: String,
    /// Argon2 PHC-format hash of the token (as produced by `generate_api_key`).
    pub hash: String,
    /// Role granted to callers presenting this key.
    pub role: String,
    /// Optional RFC 3339 expiry timestamp.
    pub expires_at: Option<String>,
}
impl std::fmt::Debug for ApiKeyEntry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ApiKeyEntry");
        out.field("name", &self.name);
        out.field("hash", &"<redacted>");
        out.field("role", &self.role);
        out.field("expires_at", &self.expires_at);
        out.finish()
    }
}
impl ApiKeyEntry {
    /// Creates a non-expiring key entry.
    #[must_use]
    pub fn new(name: impl Into<String>, hash: impl Into<String>, role: impl Into<String>) -> Self {
        let (name, hash, role) = (name.into(), hash.into(), role.into());
        Self {
            name,
            hash,
            role,
            expires_at: None,
        }
    }
    /// Returns this entry with `expires_at` set (an RFC 3339 timestamp).
    #[must_use]
    pub fn with_expiry(self, expires_at: impl Into<String>) -> Self {
        Self {
            expires_at: Some(expires_at.into()),
            ..self
        }
    }
}
/// mTLS client-certificate authentication settings, including CRL checking.
///
/// Defaults listed below come from the `#[serde(default ...)]` attributes.
/// The CRL fields are consumed outside this module; their descriptions are
/// inferred from the field names — NOTE(review): confirm against the CRL
/// implementation.
#[derive(Debug, Clone, Deserialize)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "mTLS CRL behavior is intentionally configured as independent booleans"
)]
#[non_exhaustive]
pub struct MtlsConfig {
    /// Path to the CA certificate used to verify client certificates (required).
    pub ca_cert_path: PathBuf,
    /// Whether a client certificate is mandatory (default `false`);
    /// presumably enforced by the TLS layer, not by `auth_middleware`.
    #[serde(default)]
    pub required: bool,
    /// Role given to mTLS identities (default `"viewer"`).
    #[serde(default = "default_mtls_role")]
    pub default_role: String,
    /// Enable CRL revocation checking (default `true`).
    #[serde(default = "default_true")]
    pub crl_enabled: bool,
    /// Optional CRL refresh period (humantime syntax, e.g. `"1h"`); default unset.
    #[serde(default, with = "humantime_serde::option")]
    pub crl_refresh_interval: Option<Duration>,
    /// Timeout for a single CRL download (default 30 s).
    #[serde(default = "default_crl_fetch_timeout", with = "humantime_serde")]
    pub crl_fetch_timeout: Duration,
    /// How long a stale CRL may still be used (default 24 h).
    #[serde(default = "default_crl_stale_grace", with = "humantime_serde")]
    pub crl_stale_grace: Duration,
    /// Presumably: reject clients when no usable CRL is available (default `false`).
    #[serde(default)]
    pub crl_deny_on_unavailable: bool,
    /// Presumably: only check revocation of the leaf certificate (default `false`).
    #[serde(default)]
    pub crl_end_entity_only: bool,
    /// Allow fetching CRLs over plain HTTP (default `true`).
    #[serde(default = "default_true")]
    pub crl_allow_http: bool,
    /// Presumably: enforce CRL `nextUpdate` expiry (default `true`).
    #[serde(default = "default_true")]
    pub crl_enforce_expiration: bool,
    /// Maximum concurrent CRL downloads (default 4).
    #[serde(default = "default_crl_max_concurrent_fetches")]
    pub crl_max_concurrent_fetches: usize,
    /// Maximum CRL response size in bytes (default 5 MiB).
    #[serde(default = "default_crl_max_response_bytes")]
    pub crl_max_response_bytes: u64,
    /// Rate limit for discovering new CRL URLs, per minute (default 60).
    #[serde(default = "default_crl_discovery_rate_per_min")]
    pub crl_discovery_rate_per_min: u32,
    /// Maximum number of per-host fetch semaphores kept (default 1024).
    #[serde(default = "default_crl_max_host_semaphores")]
    pub crl_max_host_semaphores: usize,
    /// Maximum number of distinct CRL URLs remembered (default 4096).
    #[serde(default = "default_crl_max_seen_urls")]
    pub crl_max_seen_urls: usize,
    /// Maximum number of cached CRLs (default 1024).
    #[serde(default = "default_crl_max_cache_entries")]
    pub crl_max_cache_entries: usize,
}
/// Serde default for `MtlsConfig::default_role`.
fn default_mtls_role() -> String {
    String::from("viewer")
}
/// Serde default for flags that are on unless configured otherwise.
const fn default_true() -> bool {
    true
}
/// Serde default for `MtlsConfig::crl_fetch_timeout`: 30 seconds.
const fn default_crl_fetch_timeout() -> Duration {
    Duration::from_millis(30_000)
}
/// Serde default for `MtlsConfig::crl_stale_grace`: one day.
const fn default_crl_stale_grace() -> Duration {
    Duration::from_secs(24 * 60 * 60)
}
/// Serde default for `MtlsConfig::crl_max_concurrent_fetches`.
const fn default_crl_max_concurrent_fetches() -> usize {
    4
}
/// Serde default for `MtlsConfig::crl_max_response_bytes`: 5 MiB.
const fn default_crl_max_response_bytes() -> u64 {
    5 << 20
}
/// Serde default for `MtlsConfig::crl_discovery_rate_per_min`.
const fn default_crl_discovery_rate_per_min() -> u32 {
    60
}
/// Serde default for `MtlsConfig::crl_max_host_semaphores`.
const fn default_crl_max_host_semaphores() -> usize {
    1024
}
/// Serde default for `MtlsConfig::crl_max_seen_urls`.
const fn default_crl_max_seen_urls() -> usize {
    4096
}
/// Serde default for `MtlsConfig::crl_max_cache_entries`.
const fn default_crl_max_cache_entries() -> usize {
    1024
}
/// Per-IP rate limiting for authentication.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct RateLimitConfig {
    /// Failed authentication attempts allowed per IP per minute (default 30).
    /// A value of 0 falls back to 30 when the limiter is built.
    #[serde(default = "default_max_attempts")]
    pub max_attempts_per_minute: u32,
    /// Requests per IP per minute allowed through to credential checking.
    /// When unset, 10x `max_attempts_per_minute` is used.
    #[serde(default)]
    pub pre_auth_max_per_minute: Option<u32>,
    /// Bound on IPs tracked per limiter (default 10 000); passed to
    /// `BoundedKeyedLimiter::new`.
    #[serde(default = "default_max_tracked_keys")]
    pub max_tracked_keys: usize,
    /// Idle time after which a tracked IP may be evicted (default 15 min).
    #[serde(default = "default_idle_eviction", with = "humantime_serde")]
    pub idle_eviction: Duration,
}
impl Default for RateLimitConfig {
    fn default() -> Self {
        let max_attempts_per_minute = default_max_attempts();
        let max_tracked_keys = default_max_tracked_keys();
        let idle_eviction = default_idle_eviction();
        Self {
            max_attempts_per_minute,
            pre_auth_max_per_minute: None,
            max_tracked_keys,
            idle_eviction,
        }
    }
}
impl RateLimitConfig {
    /// Default configuration with a custom failed-attempt quota.
    #[must_use]
    pub fn new(max_attempts_per_minute: u32) -> Self {
        let mut config = Self::default();
        config.max_attempts_per_minute = max_attempts_per_minute;
        config
    }
    /// Sets an explicit pre-auth quota, overriding the 10x default.
    #[must_use]
    pub fn with_pre_auth_max_per_minute(self, quota: u32) -> Self {
        Self {
            pre_auth_max_per_minute: Some(quota),
            ..self
        }
    }
    /// Sets the bound on tracked IPs.
    #[must_use]
    pub fn with_max_tracked_keys(self, max: usize) -> Self {
        Self {
            max_tracked_keys: max,
            ..self
        }
    }
    /// Sets the idle-eviction period.
    #[must_use]
    pub fn with_idle_eviction(self, idle: Duration) -> Self {
        Self {
            idle_eviction: idle,
            ..self
        }
    }
}
fn default_max_attempts() -> u32 {
30
}
fn default_max_tracked_keys() -> usize {
10_000
}
fn default_idle_eviction() -> Duration {
Duration::from_mins(15)
}
/// Top-level authentication configuration.
#[derive(Debug, Clone, Default, Deserialize)]
#[non_exhaustive]
pub struct AuthConfig {
    /// Master switch for authentication (default `false`).
    #[serde(default)]
    pub enabled: bool,
    /// Configured API keys (default empty).
    #[serde(default)]
    pub api_keys: Vec<ApiKeyEntry>,
    /// mTLS settings; `None` disables mTLS.
    pub mtls: Option<MtlsConfig>,
    /// Rate-limit settings; `None` presumably disables rate limiting (confirm at the call site).
    pub rate_limit: Option<RateLimitConfig>,
    /// OAuth settings; `None` disables OAuth JWT validation.
    #[cfg(feature = "oauth")]
    pub oauth: Option<crate::oauth::OAuthConfig>,
}
impl AuthConfig {
    /// Enabled configuration with the given API keys and every other section unset.
    #[must_use]
    pub fn with_keys(keys: Vec<ApiKeyEntry>) -> Self {
        Self {
            enabled: true,
            api_keys: keys,
            ..Self::default()
        }
    }
    /// Returns this configuration with rate limiting set.
    #[must_use]
    pub fn with_rate_limit(self, rate_limit: RateLimitConfig) -> Self {
        Self {
            rate_limit: Some(rate_limit),
            ..self
        }
    }
}
/// Hash-free view of one configured API key.
#[derive(Debug, Clone, serde::Serialize)]
#[non_exhaustive]
pub struct ApiKeySummary {
    /// Key name.
    pub name: String,
    /// Role granted by the key.
    pub role: String,
    /// Expiry exactly as configured (an RFC 3339 string), if any.
    pub expires_at: Option<String>,
}
/// Overview of which authentication methods are configured.
#[derive(Debug, Clone, serde::Serialize)]
#[allow(
    clippy::struct_excessive_bools,
    reason = "this is a flat summary of independent auth-method booleans"
)]
#[non_exhaustive]
pub struct AuthConfigSummary {
    /// Whether authentication is enabled.
    pub enabled: bool,
    /// `true` when at least one API key is configured.
    pub bearer: bool,
    /// `true` when an mTLS section is configured.
    pub mtls: bool,
    /// `true` when OAuth is configured; always `false` without the `oauth` feature.
    pub oauth: bool,
    /// One summary per configured API key.
    pub api_keys: Vec<ApiKeySummary>,
}
impl AuthConfig {
#[must_use]
pub fn summary(&self) -> AuthConfigSummary {
AuthConfigSummary {
enabled: self.enabled,
bearer: !self.api_keys.is_empty(),
mtls: self.mtls.is_some(),
#[cfg(feature = "oauth")]
oauth: self.oauth.is_some(),
#[cfg(not(feature = "oauth"))]
oauth: false,
api_keys: self
.api_keys
.iter()
.map(|k| ApiKeySummary {
name: k.name.clone(),
role: k.role.clone(),
expires_at: k.expires_at.clone(),
})
.collect(),
}
}
}
/// Per-IP limiter type used for both the failed-attempt limiter and the pre-auth gate.
pub(crate) type KeyedLimiter = BoundedKeyedLimiter<IpAddr>;
/// Connection metadata for a TLS connection.
///
/// `auth_middleware` reads it from `ConnectInfo<TlsConnInfo>` and, when
/// `identity` is `Some`, accepts the request as mTLS-authenticated without
/// further checks.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub(crate) struct TlsConnInfo {
    /// Remote peer address.
    pub addr: SocketAddr,
    /// Identity derived from the client certificate, if any (presumably via
    /// `extract_mtls_identity` during the handshake — confirm).
    pub identity: Option<AuthIdentity>,
}
impl TlsConnInfo {
    /// Creates connection info from a peer address and optional mTLS identity.
    #[must_use]
    pub(crate) const fn new(addr: SocketAddr, identity: Option<AuthIdentity>) -> Self {
        Self { addr, identity }
    }
}
/// Shared state for `auth_middleware`: credentials, limiters, and metrics.
#[allow(
    missing_debug_implementations,
    reason = "contains governor RateLimiter and JwksCache without Debug impls"
)]
#[non_exhaustive]
pub(crate) struct AuthState {
    /// Current API keys; replaced atomically by `reload_keys`.
    pub api_keys: ArcSwap<Vec<ApiKeyEntry>>,
    /// Per-IP limiter charged only for failed authentication attempts.
    pub rate_limiter: Option<Arc<KeyedLimiter>>,
    /// Per-IP limiter consulted for every request without an mTLS identity,
    /// before any credential is verified.
    pub pre_auth_limiter: Option<Arc<KeyedLimiter>>,
    /// JWKS cache for OAuth JWT validation; `Some` enables the JWT path.
    #[cfg(feature = "oauth")]
    pub jwks_cache: Option<Arc<crate::oauth::JwksCache>>,
    /// Identity names already logged at `info` level. Grows by one entry per
    /// distinct authenticated identity name.
    pub seen_identities: Mutex<HashSet<String>>,
    /// Success/failure counters.
    pub counters: AuthCounters,
}
impl AuthState {
    /// Atomically replaces the API key set; in-flight verifications keep the
    /// snapshot they already loaded.
    pub(crate) fn reload_keys(&self, keys: Vec<ApiKeyEntry>) {
        let count = keys.len();
        self.api_keys.store(Arc::new(keys));
        tracing::info!(keys = count, "API keys reloaded");
    }
    /// Current values of the authentication counters.
    #[must_use]
    pub(crate) fn counters_snapshot(&self) -> AuthCountersSnapshot {
        self.counters.snapshot()
    }
    /// Hash-free summaries of the currently loaded API keys.
    #[must_use]
    pub(crate) fn api_key_summaries(&self) -> Vec<ApiKeySummary> {
        self.api_keys
            .load()
            .iter()
            .map(|k| ApiKeySummary {
                name: k.name.clone(),
                role: k.role.clone(),
                expires_at: k.expires_at.clone(),
            })
            .collect()
    }
    /// Records a successful authentication and logs it: at `info` the first
    /// time an identity name is seen, at `debug` afterwards.
    fn log_auth(&self, id: &AuthIdentity, method: &str) {
        self.counters.record_success(id.method);
        let first = {
            // A poisoned lock only means another thread panicked mid-insert;
            // the set is still usable for this best-effort dedup.
            let mut seen = self
                .seen_identities
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            // Probe before inserting so the common (already-seen) path runs on
            // every authenticated request without allocating a String.
            !seen.contains(id.name.as_str()) && seen.insert(id.name.clone())
        };
        if first {
            tracing::info!(name = %id.name, role = %id.role, "{method} authenticated");
        } else {
            tracing::debug!(name = %id.name, role = %id.role, "{method} authenticated");
        }
    }
}
/// Fallback failed-attempt quota used when `max_attempts_per_minute` is zero.
const DEFAULT_AUTH_RATE: NonZeroU32 = NonZeroU32::new(30).unwrap();
/// Multiplier applied to `max_attempts_per_minute` when
/// `pre_auth_max_per_minute` is not configured.
const PRE_AUTH_DEFAULT_MULTIPLIER: u32 = 10;
/// Fallback pre-auth quota used when the resolved quota is zero.
const DEFAULT_PRE_AUTH_RATE: NonZeroU32 = NonZeroU32::new(300).unwrap();

/// Wraps a per-minute quota in a bounded per-IP limiter sized from `config`.
fn per_minute_limiter(per_minute: NonZeroU32, config: &RateLimitConfig) -> Arc<KeyedLimiter> {
    Arc::new(BoundedKeyedLimiter::new(
        governor::Quota::per_minute(per_minute),
        config.max_tracked_keys,
        config.idle_eviction,
    ))
}

/// Builds the failed-attempt limiter (`max_attempts_per_minute` per IP).
#[must_use]
pub(crate) fn build_rate_limiter(config: &RateLimitConfig) -> Arc<KeyedLimiter> {
    let per_minute =
        NonZeroU32::new(config.max_attempts_per_minute).unwrap_or(DEFAULT_AUTH_RATE);
    per_minute_limiter(per_minute, config)
}

/// Builds the pre-auth gate limiter.
///
/// Uses `pre_auth_max_per_minute` when set; otherwise 10x
/// `max_attempts_per_minute` (saturating). A resolved quota of zero falls back
/// to 300.
#[must_use]
pub(crate) fn build_pre_auth_limiter(config: &RateLimitConfig) -> Arc<KeyedLimiter> {
    let requested = match config.pre_auth_max_per_minute {
        Some(explicit) => explicit,
        None => config
            .max_attempts_per_minute
            .saturating_mul(PRE_AUTH_DEFAULT_MULTIPLIER),
    };
    let per_minute = NonZeroU32::new(requested).unwrap_or(DEFAULT_PRE_AUTH_RATE);
    per_minute_limiter(per_minute, config)
}
/// Derives an mTLS identity from a DER-encoded client certificate.
///
/// The name is the first subject CN, or, if that is absent or empty, the first
/// non-empty SAN DNS name. Names may only contain alphanumerics and
/// `- . _ @`; anything else (or no usable name, or unparseable DER) yields
/// `None`. The identity receives `default_role`.
#[must_use]
pub fn extract_mtls_identity(cert_der: &[u8], default_role: &str) -> Option<AuthIdentity> {
    let (_, cert) = X509Certificate::from_der(cert_der).ok()?;
    // An empty CN is treated as absent so the SAN fallback applies instead of
    // producing an identity with an empty name.
    let cn = cert
        .subject()
        .iter_common_name()
        .next()
        .and_then(|attr| attr.as_str().ok())
        .filter(|cn| !cn.is_empty())
        .map(String::from);
    let name = cn.or_else(|| {
        cert.subject_alternative_name()
            .ok()
            .flatten()
            .and_then(|san| {
                #[allow(clippy::wildcard_enum_match_arm)]
                san.value.general_names.iter().find_map(|gn| match gn {
                    GeneralName::DNSName(dns) if !dns.is_empty() => Some((*dns).to_owned()),
                    _ => None,
                })
            })
    })?;
    if !name
        .chars()
        .all(|c| c.is_alphanumeric() || matches!(c, '-' | '.' | '_' | '@'))
    {
        // Record with `?` (Debug) so control characters in a hostile CN/SAN are
        // escaped instead of being written raw into the log stream.
        tracing::warn!(cn = ?name, "mTLS identity rejected: invalid characters in CN/SAN");
        return None;
    }
    Some(AuthIdentity {
        name,
        role: default_role.to_owned(),
        method: AuthMethod::MtlsCertificate,
        raw_token: None,
        sub: None,
    })
}
/// Extracts the token from an `Authorization` header value using the
/// `Bearer` scheme.
///
/// The scheme is matched case-insensitively (RFC 7235 §2.1), and any number
/// of spaces may separate it from the token. Returns `None` for other schemes,
/// a missing separator, or an empty token.
fn extract_bearer(value: &str) -> Option<&str> {
    let (scheme, credentials) = value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let token = credentials.trim_start_matches(' ');
    (!token.is_empty()).then_some(token)
}
#[must_use]
pub fn verify_bearer_token(token: &str, keys: &[ApiKeyEntry]) -> Option<AuthIdentity> {
let now = chrono::Utc::now();
let mut result: Option<AuthIdentity> = None;
for key in keys {
if let Some(ref expires) = key.expires_at
&& let Ok(exp) = chrono::DateTime::parse_from_rfc3339(expires)
&& exp < now
{
continue;
}
if result.is_none()
&& let Ok(parsed_hash) = PasswordHash::new(&key.hash)
&& Argon2::default()
.verify_password(token.as_bytes(), &parsed_hash)
.is_ok()
{
result = Some(AuthIdentity {
name: key.name.clone(),
role: key.role.clone(),
method: AuthMethod::BearerToken,
raw_token: None,
sub: None,
});
}
}
result
}
/// Generates a new API key.
///
/// Returns `(token, hash)`. `token` is 32 random bytes encoded as unpadded
/// URL-safe base64 and is handed to the client. `hash` is the argon2 PHC
/// string (with a random 16-byte salt) to store in the configuration.
///
/// # Errors
///
/// Returns `McpxError::Auth` if salt encoding or hashing fails.
pub fn generate_api_key() -> Result<(String, String), McpxError> {
    let mut secret = [0u8; 32];
    rand::fill(&mut secret);
    let token = URL_SAFE_NO_PAD.encode(secret);

    let mut salt_raw = [0u8; 16];
    rand::fill(&mut salt_raw);
    let salt = SaltString::encode_b64(&salt_raw)
        .map_err(|err| McpxError::Auth(format!("salt encoding failed: {err}")))?;

    let phc = Argon2::default()
        .hash_password(token.as_bytes(), &salt)
        .map_err(|err| McpxError::Auth(format!("argon2id hashing failed: {err}")))?;
    Ok((token, phc.to_string()))
}
/// Builds the `WWW-Authenticate: Bearer` challenge for a failed request.
///
/// When `advertise_resource_metadata` is set, the challenge also carries a
/// `resource_metadata` parameter pointing at
/// `/.well-known/oauth-protected-resource`.
fn build_www_authenticate_value(
    advertise_resource_metadata: bool,
    failure: AuthFailureClass,
) -> String {
    let (error, error_description) = failure.bearer_error();
    if advertise_resource_metadata {
        format!(
            "Bearer resource_metadata=\"/.well-known/oauth-protected-resource\", error=\"{error}\", error_description=\"{error_description}\""
        )
    } else {
        format!("Bearer error=\"{error}\", error_description=\"{error_description}\"")
    }
}
/// Human-readable method name used in authentication log lines.
fn auth_method_label(method: AuthMethod) -> &'static str {
    match method {
        AuthMethod::BearerToken => "bearer token",
        AuthMethod::OAuthJwt => "OAuth JWT",
        AuthMethod::MtlsCertificate => "mTLS",
    }
}
/// Builds a 401 response carrying a `WWW-Authenticate` challenge for
/// `failure_class`.
///
/// `resource_metadata` is advertised only when an OAuth JWKS cache is
/// configured (never without the `oauth` feature).
#[cfg_attr(not(feature = "oauth"), allow(unused_variables))]
fn unauthorized_response(state: &AuthState, failure_class: AuthFailureClass) -> Response {
    #[cfg(feature = "oauth")]
    let advertise_resource_metadata = state.jwks_cache.is_some();
    #[cfg(not(feature = "oauth"))]
    let advertise_resource_metadata = false;

    let status = axum::http::StatusCode::UNAUTHORIZED;
    let headers = [(
        header::WWW_AUTHENTICATE,
        build_www_authenticate_value(advertise_resource_metadata, failure_class),
    )];
    let body = failure_class.response_body();
    (status, headers, body).into_response()
}
/// Authenticates a bearer token.
///
/// With the `oauth` feature, a configured JWKS cache, and a JWT-shaped token,
/// JWT validation is tried first; on success the raw token is attached to the
/// identity. Otherwise, or if JWT validation fails, the token is checked
/// against the API keys on a blocking thread (argon2 is CPU-heavy).
///
/// The error is `ExpiredCredential` when the token was an expired JWT, and
/// `InvalidCredential` in every other failure case.
async fn authenticate_bearer_identity(
    state: &AuthState,
    token: &str,
) -> Result<AuthIdentity, AuthFailureClass> {
    #[cfg(feature = "oauth")]
    let failure_class = match state.jwks_cache.as_ref() {
        Some(cache) if crate::oauth::looks_like_jwt(token) => {
            match cache.validate_token_with_reason(token).await {
                Ok(mut id) => {
                    id.raw_token = Some(SecretString::from(token.to_owned()));
                    return Ok(id);
                }
                Err(crate::oauth::JwtValidationFailure::Expired) => {
                    AuthFailureClass::ExpiredCredential
                }
                Err(crate::oauth::JwtValidationFailure::Invalid) => {
                    AuthFailureClass::InvalidCredential
                }
            }
        }
        _ => AuthFailureClass::InvalidCredential,
    };
    #[cfg(not(feature = "oauth"))]
    let failure_class = AuthFailureClass::InvalidCredential;

    let candidate = token.to_owned();
    let keys = state.api_keys.load_full();
    let verified =
        tokio::task::spawn_blocking(move || verify_bearer_token(&candidate, &keys)).await;
    match verified {
        Ok(Some(identity)) => Ok(identity),
        // No key matched, or the blocking task itself failed.
        Ok(None) | Err(_) => Err(failure_class),
    }
}
/// Applies the per-IP pre-authentication gate.
///
/// Returns `Some(response)` (a rate-limit rejection) when the peer IP has
/// exhausted its pre-auth quota. Returns `None` to let the request continue
/// when there is quota left, no pre-auth limiter is configured, or the peer
/// address is unknown.
fn pre_auth_gate(state: &AuthState, peer_addr: Option<SocketAddr>) -> Option<Response> {
    let (Some(limiter), Some(addr)) = (state.pre_auth_limiter.as_ref(), peer_addr) else {
        return None;
    };
    let ip = addr.ip();
    if limiter.check_key(&ip).is_ok() {
        return None;
    }
    state.counters.record_failure(AuthFailureClass::PreAuthGate);
    tracing::warn!(
        ip = %ip,
        "auth rate limited by pre-auth gate (request rejected before credential verification)"
    );
    let rejection =
        McpxError::RateLimited("too many unauthenticated requests from this source".into());
    Some(rejection.into_response())
}
pub(crate) async fn auth_middleware(
state: Arc<AuthState>,
req: Request<Body>,
next: Next,
) -> Response {
let tls_info = req.extensions().get::<ConnectInfo<TlsConnInfo>>().cloned();
let peer_addr = req
.extensions()
.get::<ConnectInfo<SocketAddr>>()
.map(|ci| ci.0)
.or_else(|| tls_info.as_ref().map(|ci| ci.0.addr));
if let Some(id) = tls_info.and_then(|ci| ci.0.identity) {
state.log_auth(&id, "mTLS");
let mut req = req;
req.extensions_mut().insert(id);
return next.run(req).await;
}
if let Some(blocked) = pre_auth_gate(&state, peer_addr) {
return blocked;
}
let failure_class = if let Some(value) = req.headers().get(header::AUTHORIZATION) {
match value.to_str().ok().and_then(extract_bearer) {
Some(token) => match authenticate_bearer_identity(&state, token).await {
Ok(id) => {
state.log_auth(&id, auth_method_label(id.method));
let mut req = req;
req.extensions_mut().insert(id);
return next.run(req).await;
}
Err(class) => class,
},
None => AuthFailureClass::InvalidCredential,
}
} else {
AuthFailureClass::MissingCredential
};
tracing::warn!(failure_class = %failure_class.as_str(), "auth failed");
if let (Some(limiter), Some(addr)) = (&state.rate_limiter, peer_addr)
&& limiter.check_key(&addr.ip()).is_err()
{
state.counters.record_failure(AuthFailureClass::RateLimited);
tracing::warn!(ip = %addr.ip(), "auth rate limited after repeated failures");
return McpxError::RateLimited("too many failed authentication attempts".into())
.into_response();
}
state.counters.record_failure(failure_class);
unauthorized_response(&state, failure_class)
}
#[cfg(test)]
mod tests {
    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use tower::ServiceExt as _;

    use super::*;

    #[test]
    fn generate_and_verify_api_key() {
        let (token, hash) = generate_api_key().unwrap();
        // 32 random bytes encode to 43 unpadded URL-safe base64 characters.
        assert_eq!(token.len(), 43);
        assert!(hash.starts_with("$argon2id$"));
        let keys = vec![ApiKeyEntry {
            name: "test".into(),
            hash,
            role: "viewer".into(),
            expires_at: None,
        }];
        let id = verify_bearer_token(&token, &keys);
        assert!(id.is_some());
        let id = id.unwrap();
        assert_eq!(id.name, "test");
        assert_eq!(id.role, "viewer");
        assert_eq!(id.method, AuthMethod::BearerToken);
    }

    #[test]
    fn wrong_token_rejected() {
        let (_token, hash) = generate_api_key().unwrap();
        let keys = vec![ApiKeyEntry {
            name: "test".into(),
            hash,
            role: "viewer".into(),
            expires_at: None,
        }];
        assert!(verify_bearer_token("wrong-token", &keys).is_none());
    }

    #[test]
    fn expired_key_rejected() {
        let (token, hash) = generate_api_key().unwrap();
        let keys = vec![ApiKeyEntry {
            name: "test".into(),
            hash,
            role: "viewer".into(),
            expires_at: Some("2020-01-01T00:00:00Z".into()),
        }];
        assert!(verify_bearer_token(&token, &keys).is_none());
    }

    #[test]
    fn future_expiry_accepted() {
        let (token, hash) = generate_api_key().unwrap();
        let keys = vec![ApiKeyEntry {
            name: "test".into(),
            hash,
            role: "viewer".into(),
            expires_at: Some("2099-01-01T00:00:00Z".into()),
        }];
        assert!(verify_bearer_token(&token, &keys).is_some());
    }

    #[test]
    fn multiple_keys_first_match_wins() {
        let (token, hash) = generate_api_key().unwrap();
        // The first entry's hash is unusable, so only the second can match.
        let keys = vec![
            ApiKeyEntry {
                name: "wrong".into(),
                hash: "$argon2id$v=19$m=19456,t=2,p=1$invalid$invalid".into(),
                role: "ops".into(),
                expires_at: None,
            },
            ApiKeyEntry {
                name: "correct".into(),
                hash,
                role: "deploy".into(),
                expires_at: None,
            },
        ];
        let id = verify_bearer_token(&token, &keys).unwrap();
        assert_eq!(id.name, "correct");
        assert_eq!(id.role, "deploy");
    }

    #[test]
    fn rate_limiter_allows_within_quota() {
        let config = RateLimitConfig {
            max_attempts_per_minute: 5,
            pre_auth_max_per_minute: None,
            ..Default::default()
        };
        let limiter = build_rate_limiter(&config);
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        for _ in 0..5 {
            assert!(limiter.check_key(&ip).is_ok());
        }
        assert!(limiter.check_key(&ip).is_err());
    }

    #[test]
    fn rate_limiter_separate_ips() {
        let config = RateLimitConfig {
            max_attempts_per_minute: 2,
            pre_auth_max_per_minute: None,
            ..Default::default()
        };
        let limiter = build_rate_limiter(&config);
        let ip1: IpAddr = "10.0.0.1".parse().unwrap();
        let ip2: IpAddr = "10.0.0.2".parse().unwrap();
        assert!(limiter.check_key(&ip1).is_ok());
        assert!(limiter.check_key(&ip1).is_ok());
        assert!(limiter.check_key(&ip1).is_err());
        assert!(limiter.check_key(&ip2).is_ok());
    }

    #[test]
    fn extract_mtls_identity_from_cn() {
        let mut params = rcgen::CertificateParams::new(vec!["test-client.local".into()]).unwrap();
        params.distinguished_name = rcgen::DistinguishedName::new();
        params
            .distinguished_name
            .push(rcgen::DnType::CommonName, "test-client");
        let cert = params
            .self_signed(&rcgen::KeyPair::generate().unwrap())
            .unwrap();
        let der = cert.der();
        let id = extract_mtls_identity(der, "ops").unwrap();
        assert_eq!(id.name, "test-client");
        assert_eq!(id.role, "ops");
        assert_eq!(id.method, AuthMethod::MtlsCertificate);
    }

    #[test]
    fn extract_mtls_identity_falls_back_to_san() {
        let mut params =
            rcgen::CertificateParams::new(vec!["san-only.example.com".into()]).unwrap();
        params.distinguished_name = rcgen::DistinguishedName::new();
        let cert = params
            .self_signed(&rcgen::KeyPair::generate().unwrap())
            .unwrap();
        let der = cert.der();
        let id = extract_mtls_identity(der, "viewer").unwrap();
        assert_eq!(id.name, "san-only.example.com");
        assert_eq!(id.role, "viewer");
    }

    #[test]
    fn extract_mtls_identity_invalid_der() {
        assert!(extract_mtls_identity(b"not-a-cert", "viewer").is_none());
    }

    /// Router with a single `POST /mcp` route guarded by `auth_middleware`.
    fn auth_router(state: Arc<AuthState>) -> axum::Router {
        axum::Router::new()
            .route("/mcp", axum::routing::post(|| async { "ok" }))
            .layer(axum::middleware::from_fn(move |req, next| {
                let s = Arc::clone(&state);
                auth_middleware(s, req, next)
            }))
    }

    /// Auth state with the given keys and no limiters.
    fn test_auth_state(keys: Vec<ApiKeyEntry>) -> Arc<AuthState> {
        Arc::new(AuthState {
            api_keys: ArcSwap::new(Arc::new(keys)),
            rate_limiter: None,
            pre_auth_limiter: None,
            #[cfg(feature = "oauth")]
            jwks_cache: None,
            seen_identities: Mutex::new(HashSet::new()),
            counters: AuthCounters::default(),
        })
    }

    #[tokio::test]
    async fn middleware_rejects_no_credentials() {
        let state = test_auth_state(vec![]);
        let app = auth_router(Arc::clone(&state));
        let req = Request::builder()
            .method(axum::http::Method::POST)
            .uri("/mcp")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        let challenge = resp
            .headers()
            .get(header::WWW_AUTHENTICATE)
            .unwrap()
            .to_str()
            .unwrap();
        assert!(challenge.contains("error=\"invalid_request\""));
        let counters = state.counters_snapshot();
        assert_eq!(counters.failure_missing_credential, 1);
    }

    #[tokio::test]
    async fn middleware_accepts_valid_bearer() {
        let (token, hash) = generate_api_key().unwrap();
        let keys = vec![ApiKeyEntry {
            name: "test-key".into(),
            hash,
            role: "ops".into(),
            expires_at: None,
        }];
        let state = test_auth_state(keys);
        let app = auth_router(Arc::clone(&state));
        let req = Request::builder()
            .method(axum::http::Method::POST)
            .uri("/mcp")
            .header("authorization", format!("Bearer {token}"))
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let counters = state.counters_snapshot();
        assert_eq!(counters.success_bearer, 1);
    }

    #[tokio::test]
    async fn middleware_rejects_wrong_bearer() {
        let (_token, hash) = generate_api_key().unwrap();
        let keys = vec![ApiKeyEntry {
            name: "test-key".into(),
            hash,
            role: "ops".into(),
            expires_at: None,
        }];
        let state = test_auth_state(keys);
        let app = auth_router(Arc::clone(&state));
        let req = Request::builder()
            .method(axum::http::Method::POST)
            .uri("/mcp")
            .header("authorization", "Bearer wrong-token-here")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        let challenge = resp
            .headers()
            .get(header::WWW_AUTHENTICATE)
            .unwrap()
            .to_str()
            .unwrap();
        assert!(challenge.contains("error=\"invalid_token\""));
        let counters = state.counters_snapshot();
        assert_eq!(counters.failure_invalid_credential, 1);
    }

    #[tokio::test]
    async fn middleware_rate_limits() {
        let state = Arc::new(AuthState {
            api_keys: ArcSwap::new(Arc::new(vec![])),
            rate_limiter: Some(build_rate_limiter(&RateLimitConfig {
                max_attempts_per_minute: 1,
                pre_auth_max_per_minute: None,
                ..Default::default()
            })),
            pre_auth_limiter: None,
            #[cfg(feature = "oauth")]
            jwks_cache: None,
            seen_identities: Mutex::new(HashSet::new()),
            counters: AuthCounters::default(),
        });
        let app = auth_router(Arc::clone(&state));
        // The failure limiter is keyed by peer IP, so each request must carry a
        // `ConnectInfo<SocketAddr>`; without one the limiter is never consulted.
        let peer: SocketAddr = "10.0.0.30:54321".parse().unwrap();
        let unauthenticated = || {
            let mut req = Request::builder()
                .method(axum::http::Method::POST)
                .uri("/mcp")
                .body(Body::empty())
                .unwrap();
            req.extensions_mut().insert(ConnectInfo(peer));
            req
        };
        let first = app.clone().oneshot(unauthenticated()).await.unwrap();
        assert_eq!(
            first.status(),
            StatusCode::UNAUTHORIZED,
            "first failure is within the quota of 1/min and gets a plain 401"
        );
        let second = app.oneshot(unauthenticated()).await.unwrap();
        assert_eq!(
            second.status(),
            StatusCode::TOO_MANY_REQUESTS,
            "second failure from the same IP exceeds the quota and must get 429"
        );
        let counters = state.counters_snapshot();
        assert_eq!(counters.failure_missing_credential, 1);
        assert_eq!(counters.failure_rate_limited, 1);
    }

    #[test]
    fn rate_limit_semantics_failed_only() {
        let config = RateLimitConfig {
            max_attempts_per_minute: 3,
            pre_auth_max_per_minute: None,
            ..Default::default()
        };
        let limiter = build_rate_limiter(&config);
        let ip: IpAddr = "192.168.1.100".parse().unwrap();
        assert!(limiter.check_key(&ip).is_ok(), "failure 1 should be allowed");
        assert!(limiter.check_key(&ip).is_ok(), "failure 2 should be allowed");
        assert!(limiter.check_key(&ip).is_ok(), "failure 3 should be allowed");
        assert!(limiter.check_key(&ip).is_err(), "failure 4 should be blocked");
    }

    #[test]
    fn pre_auth_default_multiplier_is_10x() {
        let config = RateLimitConfig {
            max_attempts_per_minute: 5,
            pre_auth_max_per_minute: None,
            ..Default::default()
        };
        let limiter = build_pre_auth_limiter(&config);
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        for i in 0..50 {
            assert!(
                limiter.check_key(&ip).is_ok(),
                "pre-auth attempt {i} (of expected 50) should be allowed under default 10x multiplier"
            );
        }
        assert!(
            limiter.check_key(&ip).is_err(),
            "pre-auth attempt 51 should be blocked (quota is 50, not unbounded)"
        );
    }

    #[test]
    fn pre_auth_explicit_override_wins() {
        let config = RateLimitConfig {
            max_attempts_per_minute: 100,
            pre_auth_max_per_minute: Some(2),
            ..Default::default()
        };
        let limiter = build_pre_auth_limiter(&config);
        let ip: IpAddr = "10.0.0.2".parse().unwrap();
        assert!(limiter.check_key(&ip).is_ok(), "attempt 1 allowed");
        assert!(limiter.check_key(&ip).is_ok(), "attempt 2 allowed");
        assert!(
            limiter.check_key(&ip).is_err(),
            "attempt 3 must be blocked (explicit override of 2 wins over 10x default of 1000)"
        );
    }

    #[tokio::test]
    async fn pre_auth_gate_blocks_before_argon2_verification() {
        let (_token, hash) = generate_api_key().unwrap();
        let keys = vec![ApiKeyEntry {
            name: "test-key".into(),
            hash,
            role: "ops".into(),
            expires_at: None,
        }];
        let config = RateLimitConfig {
            max_attempts_per_minute: 100,
            pre_auth_max_per_minute: Some(1),
            ..Default::default()
        };
        let state = Arc::new(AuthState {
            api_keys: ArcSwap::new(Arc::new(keys)),
            rate_limiter: None,
            pre_auth_limiter: Some(build_pre_auth_limiter(&config)),
            #[cfg(feature = "oauth")]
            jwks_cache: None,
            seen_identities: Mutex::new(HashSet::new()),
            counters: AuthCounters::default(),
        });
        let app = auth_router(Arc::clone(&state));
        let peer: SocketAddr = "10.0.0.10:54321".parse().unwrap();
        let mut req1 = Request::builder()
            .method(axum::http::Method::POST)
            .uri("/mcp")
            .header("authorization", "Bearer obviously-not-a-real-token")
            .body(Body::empty())
            .unwrap();
        req1.extensions_mut().insert(ConnectInfo(peer));
        let resp1 = app.clone().oneshot(req1).await.unwrap();
        assert_eq!(
            resp1.status(),
            StatusCode::UNAUTHORIZED,
            "first attempt: gate has quota, falls through to bearer auth which fails with 401"
        );
        let mut req2 = Request::builder()
            .method(axum::http::Method::POST)
            .uri("/mcp")
            .header("authorization", "Bearer also-not-a-real-token")
            .body(Body::empty())
            .unwrap();
        req2.extensions_mut().insert(ConnectInfo(peer));
        let resp2 = app.oneshot(req2).await.unwrap();
        assert_eq!(
            resp2.status(),
            StatusCode::TOO_MANY_REQUESTS,
            "second attempt from same IP: pre-auth gate must reject with 429"
        );
        let counters = state.counters_snapshot();
        assert_eq!(
            counters.failure_pre_auth_gate, 1,
            "exactly one request must have been rejected by the pre-auth gate"
        );
        assert_eq!(
            counters.failure_invalid_credential, 1,
            "bearer verification must run exactly once (only the un-gated first request)"
        );
    }

    #[tokio::test]
    async fn pre_auth_gate_does_not_throttle_mtls() {
        let config = RateLimitConfig {
            max_attempts_per_minute: 100,
            pre_auth_max_per_minute: Some(1),
            ..Default::default()
        };
        let state = Arc::new(AuthState {
            api_keys: ArcSwap::new(Arc::new(vec![])),
            rate_limiter: None,
            pre_auth_limiter: Some(build_pre_auth_limiter(&config)),
            #[cfg(feature = "oauth")]
            jwks_cache: None,
            seen_identities: Mutex::new(HashSet::new()),
            counters: AuthCounters::default(),
        });
        let app = auth_router(Arc::clone(&state));
        let peer: SocketAddr = "10.0.0.20:54321".parse().unwrap();
        let identity = AuthIdentity {
            name: "cn=test-client".into(),
            role: "viewer".into(),
            method: AuthMethod::MtlsCertificate,
            raw_token: None,
            sub: None,
        };
        let tls_info = TlsConnInfo::new(peer, Some(identity));
        for i in 0..3 {
            let mut req = Request::builder()
                .method(axum::http::Method::POST)
                .uri("/mcp")
                .body(Body::empty())
                .unwrap();
            req.extensions_mut().insert(ConnectInfo(tls_info.clone()));
            let resp = app.clone().oneshot(req).await.unwrap();
            assert_eq!(
                resp.status(),
                StatusCode::OK,
                "mTLS request {i} must succeed: pre-auth gate must not apply to mTLS callers"
            );
        }
        let counters = state.counters_snapshot();
        assert_eq!(
            counters.failure_pre_auth_gate, 0,
            "pre-auth gate counter must remain at zero: mTLS bypasses the gate"
        );
        assert_eq!(
            counters.success_mtls, 3,
            "all three mTLS requests must have been counted as successful"
        );
    }

    #[test]
    fn extract_bearer_accepts_canonical_case() {
        assert_eq!(extract_bearer("Bearer abc123"), Some("abc123"));
    }

    #[test]
    fn extract_bearer_is_case_insensitive_per_rfc7235() {
        for header in &[
            "bearer abc123",
            "BEARER abc123",
            "BeArEr abc123",
            "bEaReR abc123",
        ] {
            assert_eq!(
                extract_bearer(header),
                Some("abc123"),
                "header {header:?} must parse as a Bearer token (RFC 7235 §2.1)"
            );
        }
    }

    #[test]
    fn extract_bearer_rejects_other_schemes() {
        assert_eq!(extract_bearer("Basic dXNlcjpwYXNz"), None);
        assert_eq!(extract_bearer("Digest username=\"x\""), None);
        assert_eq!(extract_bearer("Token abc123"), None);
    }

    #[test]
    fn extract_bearer_rejects_malformed() {
        assert_eq!(extract_bearer(""), None);
        assert_eq!(extract_bearer("Bearer"), None);
        assert_eq!(extract_bearer("Bearer "), None);
        // Several separator spaces but no token.
        assert_eq!(extract_bearer("Bearer    "), None);
    }

    #[test]
    fn extract_bearer_tolerates_extra_separator_whitespace() {
        // Two and five spaces between scheme and token, respectively.
        assert_eq!(extract_bearer("Bearer  abc123"), Some("abc123"));
        assert_eq!(extract_bearer("Bearer     abc123"), Some("abc123"));
    }

    #[test]
    fn auth_identity_debug_redacts_raw_token() {
        let id = AuthIdentity {
            name: "alice".into(),
            role: "admin".into(),
            method: AuthMethod::OAuthJwt,
            raw_token: Some(SecretString::from("super-secret-jwt-payload-xyz")),
            sub: Some("keycloak-uuid-2f3c8b".into()),
        };
        let dbg = format!("{id:?}");
        assert!(dbg.contains("alice"), "name should be visible: {dbg}");
        assert!(dbg.contains("admin"), "role should be visible: {dbg}");
        assert!(dbg.contains("OAuthJwt"), "method should be visible: {dbg}");
        assert!(
            !dbg.contains("super-secret-jwt-payload-xyz"),
            "raw_token must be redacted in Debug output: {dbg}"
        );
        assert!(
            !dbg.contains("keycloak-uuid-2f3c8b"),
            "sub must be redacted in Debug output: {dbg}"
        );
        assert!(dbg.contains("<redacted>"), "redaction marker missing: {dbg}");
    }

    #[test]
    fn auth_identity_debug_marks_absent_secrets() {
        let id = AuthIdentity {
            name: "viewer-key".into(),
            role: "viewer".into(),
            method: AuthMethod::BearerToken,
            raw_token: None,
            sub: None,
        };
        let dbg = format!("{id:?}");
        assert!(dbg.contains("<none>"), "absent secrets should be marked: {dbg}");
        assert!(
            !dbg.contains("<redacted>"),
            "no <redacted> marker when secrets are absent: {dbg}"
        );
    }

    #[test]
    fn api_key_entry_debug_redacts_hash() {
        let entry = ApiKeyEntry {
            name: "viewer-key".into(),
            hash: "$argon2id$v=19$m=19456,t=2,p=1$c2FsdHNhbHQ$h4sh3dPa55w0rd".into(),
            role: "viewer".into(),
            expires_at: Some("2030-01-01T00:00:00Z".into()),
        };
        let dbg = format!("{entry:?}");
        assert!(dbg.contains("viewer-key"));
        assert!(dbg.contains("viewer"));
        assert!(dbg.contains("2030-01-01T00:00:00Z"));
        assert!(
            !dbg.contains("$argon2id$"),
            "argon2 hash leaked into Debug output: {dbg}"
        );
        assert!(
            !dbg.contains("h4sh3dPa55w0rd"),
            "hash digest leaked into Debug output: {dbg}"
        );
        assert!(dbg.contains("<redacted>"), "redaction marker missing: {dbg}");
    }
}