use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex as StdMutex};
use std::time::{Duration, Instant};
use arc_swap::ArcSwapOption;
use jsonwebtoken::{Algorithm, DecodingKey, Validation, jwk::JwkSet};
use tokio::sync::Semaphore;
use rusty_gasket::auth::backend::AuthBackend;
use rusty_gasket::auth::error::AuthError;
use rusty_gasket::auth::identity::Identity;
/// Failures specific to fetching a JWKS document.
///
/// These are surfaced to callers wrapped in `AuthError::BackendError`.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum JwksError {
    /// The JWKS endpoint answered with a status outside the 2xx range.
    #[error("JWKS endpoint returned HTTP {status}")]
    HttpStatus {
        /// Status code returned by the endpoint.
        status: http::StatusCode,
    },
    /// The response body grew past the configured byte limit while streaming.
    #[error("JWKS response exceeds maximum body size of {max_bytes} bytes")]
    BodyTooLarge {
        /// The configured limit that was exceeded.
        max_bytes: usize,
    },
    /// The spawned fetch task failed to run to completion (tokio `JoinError`:
    /// it panicked or was cancelled).
    #[error("JWKS fetch task did not complete: {0}")]
    FetchTask(#[source] tokio::task::JoinError),
}
/// How long a fetched JWKS is treated as fresh before it is refetched (5 minutes).
const DEFAULT_JWKS_CACHE_TTL: Duration = Duration::from_secs(5 * 60);
/// Total request timeout used by the JWKS HTTP client.
const DEFAULT_JWKS_TIMEOUT: Duration = Duration::from_secs(10);
/// Upper bound on the JWKS response body (1 MiB).
const DEFAULT_JWKS_MAX_BODY_SIZE: usize = 1024 * 1024;
/// Where the JWT backend looks for the token on an incoming request.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub enum TokenSource {
    /// `Authorization: Bearer <token>` (the default).
    #[default]
    BearerHeader,
    /// A cookie with the given name. The request is refused if the cookie
    /// appears more than once.
    Cookie(String),
    /// A custom header with the given name. An optional leading `Bearer `
    /// scheme (any ASCII case) is stripped.
    Header(String),
}
/// Converts the verified JWT claims into an [`Identity`].
///
/// The trait is `Send + Sync + 'static` so implementations can be boxed inside
/// a `ClaimsMapperHandle` and shared by the backend.
pub trait ClaimsMapper: Send + Sync + 'static {
    /// Maps the decoded claim set (a JSON value) to an identity. Returns an
    /// error when required claims are missing or malformed.
    fn map_claims(&self, claims: &serde_json::Value) -> Result<Identity, AuthError>;
}
/// Type-erased owner of a boxed [`ClaimsMapper`].
///
/// Lets the backend hold any mapper implementation behind one concrete type.
pub struct ClaimsMapperHandle {
    mapper: Box<dyn ClaimsMapper>,
}
impl ClaimsMapperHandle {
    /// Boxes `mapper` into a new handle.
    #[must_use]
    pub fn new(mapper: impl ClaimsMapper) -> Self {
        let mapper: Box<dyn ClaimsMapper> = Box::new(mapper);
        Self { mapper }
    }
    /// Delegates to the wrapped mapper's [`ClaimsMapper::map_claims`].
    fn map_claims(&self, claims: &serde_json::Value) -> Result<Identity, AuthError> {
        let inner: &dyn ClaimsMapper = &*self.mapper;
        inner.map_claims(claims)
    }
}
impl std::fmt::Debug for ClaimsMapperHandle {
    // Only the type name is printed; the wrapped mapper need not be `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ClaimsMapperHandle");
        out.finish_non_exhaustive()
    }
}
#[derive(Debug, Default)]
pub struct StandardClaimsMapper;
impl ClaimsMapper for StandardClaimsMapper {
fn map_claims(&self, claims: &serde_json::Value) -> Result<Identity, AuthError> {
map_standard_claims(claims, "jwt", None)
}
}
/// Whether the OAuth mapper copies the `client_id` claim into the identity.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
#[non_exhaustive]
pub enum ClientIdClaim {
    /// Drop `client_id`; no [`OAuthTokenClaims`] attribute is attached.
    Ignore,
    /// Attach an [`OAuthTokenClaims`] attribute carrying `client_id` (the default).
    #[default]
    Preserve,
}
/// OAuth-specific claims attached to an identity as a typed attribute.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
#[non_exhaustive]
pub struct OAuthTokenClaims {
    client_id: Option<String>,
}
impl OAuthTokenClaims {
    /// Wraps an optional `client_id` value.
    #[must_use]
    pub fn new(client_id: Option<String>) -> Self {
        OAuthTokenClaims { client_id }
    }
    /// The `client_id` claim, if the token carried one as a string.
    #[must_use]
    pub fn client_id(&self) -> Option<&str> {
        self.client_id.as_deref()
    }
}
/// Claims mapper for OAuth access tokens.
///
/// Uses a configurable auth-method label (default `"oauth-jwt"`) and can
/// optionally preserve the `client_id` claim.
#[derive(Debug, Clone, Copy)]
pub struct OAuthClaimsMapper {
    auth_method: &'static str,
    client_id: ClientIdClaim,
}
impl Default for OAuthClaimsMapper {
    fn default() -> Self {
        Self::DEFAULT
    }
}
impl OAuthClaimsMapper {
    /// Baseline configuration shared by `new()` and `Default`.
    const DEFAULT: Self = Self {
        auth_method: "oauth-jwt",
        client_id: ClientIdClaim::Preserve,
    };
    /// Creates a mapper with auth method `"oauth-jwt"` that preserves `client_id`.
    #[must_use]
    pub fn new() -> Self {
        Self::DEFAULT
    }
    /// Overrides the auth-method label recorded on the identity.
    #[must_use]
    pub const fn auth_method(mut self, auth_method: &'static str) -> Self {
        self.auth_method = auth_method;
        self
    }
    /// Shorthand for `client_id_claim(ClientIdClaim::Preserve)`.
    #[must_use]
    pub const fn preserve_client_id(self) -> Self {
        self.client_id_claim(ClientIdClaim::Preserve)
    }
    /// Shorthand for `client_id_claim(ClientIdClaim::Ignore)`.
    #[must_use]
    pub const fn ignore_client_id(self) -> Self {
        self.client_id_claim(ClientIdClaim::Ignore)
    }
    /// Sets how the `client_id` claim is handled.
    #[must_use]
    pub const fn client_id_claim(mut self, client_id: ClientIdClaim) -> Self {
        self.client_id = client_id;
        self
    }
}
impl ClaimsMapper for OAuthClaimsMapper {
    /// Maps the standard claims under this mapper's auth method.
    ///
    /// With [`ClientIdClaim::Preserve`], it also attaches an [`OAuthTokenClaims`]
    /// attribute holding the string `client_id` claim, or `None` if absent or
    /// not a string.
    fn map_claims(&self, claims: &serde_json::Value) -> Result<Identity, AuthError> {
        let oauth_claims = match self.client_id {
            ClientIdClaim::Ignore => None,
            ClientIdClaim::Preserve => {
                let client_id = claims
                    .get("client_id")
                    .and_then(serde_json::Value::as_str)
                    .map(str::to_owned);
                Some(OAuthTokenClaims::new(client_id))
            }
        };
        map_standard_claims(claims, self.auth_method, oauth_claims)
    }
}
/// Builds an [`Identity`] from the common JWT claims.
///
/// - `sub` (required, must be a string) becomes the subject.
/// - `scope` is split on whitespace into the scope set (see [`scope_claims`]).
/// - `name`, if it is a string, becomes the display name.
/// - `oauth_claims`, when provided, is attached as an identity attribute.
///
/// # Errors
/// Returns [`AuthError::TokenValidation`] when `sub` is missing or not a string.
fn map_standard_claims(
    claims: &serde_json::Value,
    auth_method: &'static str,
    oauth_claims: Option<OAuthTokenClaims>,
) -> Result<Identity, AuthError> {
    let Some(subject) = claims.get("sub").and_then(serde_json::Value::as_str) else {
        return Err(AuthError::TokenValidation("Missing 'sub' claim".to_string()));
    };
    let mut builder =
        Identity::builder(subject.to_string(), auth_method).scopes(scope_claims(claims));
    if let Some(name) = claims.get("name").and_then(serde_json::Value::as_str) {
        builder = builder.display_name(name.to_string());
    }
    if let Some(attribute) = oauth_claims {
        builder = builder.attribute(attribute);
    }
    Ok(builder.build())
}
fn scope_claims(claims: &serde_json::Value) -> HashSet<String> {
claims
.get("scope")
.and_then(|claim| claim.as_str())
.map(|scope| scope.split_whitespace().map(String::from).collect())
.unwrap_or_default()
}
/// How the backend obtains verification keys.
enum KeyResolver {
    /// A single key configured at build time, with a prebuilt validation.
    Static(Box<StaticKey>),
    /// Keys fetched from a JWKS endpoint and cached; shared via `Arc` so a
    /// spawned fetch task can hold its own reference.
    Jwks(Arc<JwksKeyStore>),
}
/// Static verification key plus the validation rules applied to every token.
struct StaticKey {
    key: DecodingKey,
    validation: Validation,
}
/// JWKS-backed key store with a time-bounded cache.
struct JwksKeyStore {
    /// Absolute JWKS URL (scheme/host checked at build time).
    url: String,
    /// How long a fetched key set is treated as fresh.
    cache_ttl: Duration,
    /// Most recently prepared key set, or `None` before the first successful fetch.
    cached: ArcSwapOption<PreparedJwks>,
    /// Single-permit semaphore that serializes JWKS fetches.
    fetch_lock: Semaphore,
    /// HTTP client used for JWKS fetches.
    http_client: reqwest::Client,
    /// Maximum accepted response body size in bytes.
    max_body_size: usize,
    /// Token header `alg` values accepted for JWKS verification.
    allowed_algorithms: Vec<Algorithm>,
    /// Expected `aud`, if audience validation is configured.
    audience: Option<String>,
    /// Expected `iss`, if issuer validation is configured.
    issuer: Option<String>,
    validate_exp: bool,
    validate_nbf: bool,
    /// Clock-skew allowance for time-based claims, in seconds.
    leeway_secs: u64,
    /// `"<reason>:<kid>"` entries for key-skip warnings already logged.
    warned_drops: StdMutex<HashSet<String>>,
}
/// A JWKS converted into ready-to-use decoding keys.
struct PreparedJwks {
    /// When this set was prepared; compared against `cache_ttl`.
    fetched_at: Instant,
    /// Usable keys indexed by `kid`.
    keys: HashMap<String, Arc<PreparedKey>>,
}
/// One usable JWK: its decoding key and the algorithm declared by its `alg`.
struct PreparedKey {
    key: DecodingKey,
    algorithm: Algorithm,
}
/// Result of consulting the JWKS cache for a `kid`.
enum CacheLookup {
    /// Cache is fresh and contains the key.
    Hit(Arc<PreparedKey>),
    /// Cache is fresh but has no such `kid`; answered without refetching.
    FreshButMissing,
    /// Cache is empty or older than `cache_ttl`; a fetch is needed.
    Stale,
}
/// Builds the "unknown key id" validation error.
///
/// The token-supplied `kid` is sanitized first so it cannot smuggle control
/// characters into logs or error text.
fn kid_not_found(kid: &str) -> AuthError {
    let safe_kid = sanitize_for_log(kid);
    AuthError::TokenValidation(format!("Key '{safe_kid}' not found in JWKS"))
}
/// Makes untrusted text safe to embed in a log line or error message.
///
/// Drops every control character (newlines, CR, ESC, ...) and keeps at most
/// the first 128 remaining characters.
fn sanitize_for_log(s: &str) -> String {
    const MAX_CHARS: usize = 128;
    let mut cleaned = String::new();
    let mut kept = 0usize;
    for ch in s.chars() {
        if kept == MAX_CHARS {
            break;
        }
        if ch.is_control() {
            continue;
        }
        cleaned.push(ch);
        kept += 1;
    }
    cleaned
}
/// Removes an optional `Bearer` scheme from a custom-header value.
///
/// Leading whitespace is ignored. The scheme matches case-insensitively and
/// must be followed by a space or tab. Whitespace after the scheme is also
/// trimmed. Values without the scheme are returned with only leading
/// whitespace removed.
fn strip_bearer_prefix(s: &str) -> &str {
    let value = s.trim_start();
    let bytes = value.as_bytes();
    let has_scheme = bytes.len() > 6
        && bytes[..6].eq_ignore_ascii_case(b"bearer")
        && (bytes[6] == b' ' || bytes[6] == b'\t');
    if !has_scheme {
        return value;
    }
    // Bytes 0..=6 are ASCII here, so index 7 is a char boundary.
    value[7..].trim_start()
}
impl JwksKeyStore {
    /// Resolves the decoding key and a per-token [`Validation`] for `token`.
    ///
    /// Steps, in order:
    /// 1. Decode the (not yet verified) JWT header and require a `kid`.
    /// 2. Reject any header `alg` that is not in `allowed_algorithms`.
    /// 3. Look up the `kid` via [`Self::find_key`], which may fetch the JWKS.
    /// 4. Require the header `alg` to equal the `alg` recorded for that JWK,
    ///    so a token cannot choose a different algorithm for a published key.
    /// 5. Build a [`Validation`] pinned to that one algorithm, using the
    ///    configured exp/nbf/leeway/audience/issuer settings.
    ///
    /// # Errors
    /// Returns [`AuthError::TokenValidation`] for a malformed header, a missing
    /// `kid`, a disallowed or mismatched `alg`, or an unknown `kid`. Errors
    /// from [`Self::find_key`] are propagated.
    async fn resolve_key(
        self: &Arc<Self>,
        token: &str,
    ) -> Result<(DecodingKey, Validation), AuthError> {
        let header = jsonwebtoken::decode_header(token)
            .map_err(|e| AuthError::TokenValidation(format!("Invalid JWT header: {e}")))?;
        let kid = header
            .kid
            .as_deref()
            .ok_or_else(|| AuthError::TokenValidation("JWT missing 'kid' header".to_string()))?;
        // Checked before any cache/network work so disallowed algorithms fail cheaply.
        if !self.allowed_algorithms.contains(&header.alg) {
            return Err(AuthError::TokenValidation(format!(
                "Algorithm {:?} is not in the JWKS allow-list",
                header.alg
            )));
        }
        let prepared = self.find_key(kid).await?;
        if prepared.algorithm != header.alg {
            return Err(AuthError::TokenValidation(format!(
                "JWT alg {:?} does not match JWKS key alg {:?}",
                header.alg, prepared.algorithm
            )));
        }
        let mut validation = Validation::new(prepared.algorithm);
        validation.validate_exp = self.validate_exp;
        validation.validate_nbf = self.validate_nbf;
        validation.leeway = self.leeway_secs;
        // `Validation::new` lists `exp` as a required claim; drop it when expiry
        // checks are disabled so tokens without `exp` are still accepted.
        if !self.validate_exp {
            validation.required_spec_claims.remove("exp");
        }
        if let Some(ref aud) = self.audience {
            validation.set_audience(&[aud]);
        }
        if let Some(ref iss) = self.issuer {
            validation.set_issuer(&[iss]);
        }
        Ok((prepared.key.clone(), validation))
    }
    /// Returns the prepared key for `kid`, fetching the JWKS when the cache is
    /// empty or older than `cache_ttl`.
    ///
    /// A fresh cache that lacks `kid` answers "not found" without refetching.
    /// Unknown `kid` values therefore cannot trigger fetches before the TTL
    /// expires.
    ///
    /// Fetches are serialized through `fetch_lock`. The cache is checked again
    /// after acquiring the permit, so waiters reuse a set fetched by the
    /// previous holder.
    ///
    /// The fetch runs in a spawned task. It still completes and populates the
    /// cache even if this caller's future is dropped mid-fetch.
    ///
    /// NOTE(review): failed fetches are not cached. While the endpoint is down,
    /// each queued caller performs its own fetch (bounded by the client
    /// timeout) one after another.
    async fn find_key(self: &Arc<Self>, kid: &str) -> Result<Arc<PreparedKey>, AuthError> {
        // Fast path: no permit needed when the cache can answer.
        match self.cache_lookup(kid) {
            CacheLookup::Hit(prepared) => return Ok(prepared),
            CacheLookup::FreshButMissing => return Err(kid_not_found(kid)),
            CacheLookup::Stale => {}
        }
        let _fetch_guard = self
            .fetch_lock
            .acquire()
            .await
            .map_err(|e| AuthError::BackendError(Box::new(e)))?;
        // Re-check: another caller may have refreshed the cache while we waited.
        match self.cache_lookup(kid) {
            CacheLookup::Hit(prepared) => return Ok(prepared),
            CacheLookup::FreshButMissing => return Err(kid_not_found(kid)),
            CacheLookup::Stale => {}
        }
        let store = Arc::clone(self);
        let fetch_task = tokio::spawn(async move {
            let jwks = store.fetch_jwks().await?;
            let prepared = Arc::new(prepare_jwks(&jwks, &store.warned_drops));
            store.cached.store(Some(Arc::clone(&prepared)));
            Ok::<_, AuthError>(prepared)
        });
        let prepared = match fetch_task.await {
            Ok(Ok(prepared)) => prepared,
            Ok(Err(e)) => return Err(e),
            Err(join_err) => {
                return Err(AuthError::BackendError(Box::new(JwksError::FetchTask(
                    join_err,
                ))));
            }
        };
        prepared
            .keys
            .get(kid)
            .cloned()
            .ok_or_else(|| kid_not_found(kid))
    }
    /// Classifies the current cache snapshot for `kid` without blocking or fetching.
    fn cache_lookup(&self, kid: &str) -> CacheLookup {
        let snapshot = self.cached.load();
        let Some(prepared) = snapshot.as_ref() else {
            return CacheLookup::Stale;
        };
        if prepared.fetched_at.elapsed() >= self.cache_ttl {
            return CacheLookup::Stale;
        }
        match prepared.keys.get(kid) {
            Some(key) => CacheLookup::Hit(Arc::clone(key)),
            None => CacheLookup::FreshButMissing,
        }
    }
    /// GETs the JWKS document and parses it as a [`JwkSet`].
    ///
    /// # Errors
    /// - [`JwksError::HttpStatus`] for a non-2xx response.
    /// - [`JwksError::BodyTooLarge`] once the streamed body would exceed
    ///   `max_body_size`; streaming stops at that point, so an oversized
    ///   response is never fully buffered.
    /// - Transport and JSON errors, wrapped in `AuthError::BackendError`.
    async fn fetch_jwks(&self) -> Result<JwkSet, AuthError> {
        tracing::debug!(url = %self.url, "Fetching JWKS");
        let response = self
            .http_client
            .get(&self.url)
            .send()
            .await
            .map_err(|e| AuthError::BackendError(Box::new(e)))?;
        if !response.status().is_success() {
            return Err(AuthError::BackendError(Box::new(JwksError::HttpStatus {
                status: response.status(),
            })));
        }
        let mut body = Vec::with_capacity(4096);
        let mut stream = response;
        while let Some(chunk) = stream
            .chunk()
            .await
            .map_err(|e| AuthError::BackendError(Box::new(e)))?
        {
            // Enforce the limit before appending so memory stays bounded.
            if body.len().saturating_add(chunk.len()) > self.max_body_size {
                return Err(AuthError::BackendError(Box::new(JwksError::BodyTooLarge {
                    max_bytes: self.max_body_size,
                })));
            }
            body.extend_from_slice(&chunk);
        }
        serde_json::from_slice::<JwkSet>(&body).map_err(|e| AuthError::BackendError(Box::new(e)))
    }
}
fn prepare_jwks(jwks: &JwkSet, warned_drops: &StdMutex<HashSet<String>>) -> PreparedJwks {
let mut keys: HashMap<String, Arc<PreparedKey>> = HashMap::with_capacity(jwks.keys.len());
for jwk in &jwks.keys {
let Some(kid) = jwk.common.key_id.as_deref() else {
warn_once(warned_drops, "<no-kid>", "no `kid`", || {
tracing::warn!("Skipping JWKS key with no `kid`");
});
continue;
};
let Some(algorithm) = jwk
.common
.key_algorithm
.and_then(|a| a.to_string().parse::<Algorithm>().ok())
else {
warn_once(warned_drops, kid, "no-usable-alg", || {
tracing::warn!(
kid = %sanitize_for_log(kid),
"Skipping JWKS key with no usable `alg`",
);
});
continue;
};
match DecodingKey::from_jwk(jwk) {
Ok(key) => {
keys.insert(kid.to_string(), Arc::new(PreparedKey { key, algorithm }));
}
Err(e) => {
warn_once(warned_drops, kid, "decode-key-failed", || {
tracing::warn!(
kid = %sanitize_for_log(kid),
error = %e,
"Failed to build DecodingKey from JWK; skipping",
);
});
}
}
}
PreparedJwks {
fetched_at: Instant::now(),
keys,
}
}
/// Runs `emit` only the first time a given (`reason`, `kid`) pair is seen.
///
/// Seen pairs are recorded in `warned` as `"<reason>:<kid>"`. A poisoned
/// mutex is recovered rather than propagated, because this is only log
/// de-duplication. `emit` runs while the set is locked.
fn warn_once(
    warned: &StdMutex<HashSet<String>>,
    kid: &str,
    reason: &'static str,
    emit: impl FnOnce(),
) {
    let entry = format!("{reason}:{kid}");
    let mut seen = warned
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let first_time = seen.insert(entry);
    if first_time {
        emit();
    }
}
/// JWT authentication backend.
///
/// Extracts a token from the configured [`TokenSource`], verifies it with a
/// static key or JWKS-fetched keys, and maps the claims to an [`Identity`].
pub struct JwtBackend {
    key_resolver: KeyResolver,
    token_source: TokenSource,
    claims_mapper: ClaimsMapperHandle,
}
impl std::fmt::Debug for JwtBackend {
    // Key material and the mapper are deliberately omitted; only the key mode
    // and token source are printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("JwtBackend");
        match self.key_resolver {
            KeyResolver::Static(_) => out.field("key_mode", &"static"),
            KeyResolver::Jwks(_) => out.field("key_mode", &"jwks"),
        };
        out.field("token_source", &self.token_source)
            .finish_non_exhaustive()
    }
}
impl JwtBackend {
    /// Returns a [`JwtBackendBuilder`] with default settings.
    pub fn builder() -> JwtBackendBuilder {
        JwtBackendBuilder::default()
    }
    /// Extracts the raw token according to the configured [`TokenSource`].
    ///
    /// Returns `None` ("no credential presented") in these cases:
    /// - the header or cookie is absent, or is not visible ASCII;
    /// - the named cookie appears more than once (ambiguous);
    /// - the extracted token is empty.
    ///
    /// Treating an empty value as absent lets a cleared cookie
    /// (`auth_token=`), an empty custom header, or a bare `Bearer ` fall
    /// through to anonymous handling instead of failing JWT decoding.
    fn extract_token(&self, headers: &http::HeaderMap) -> Option<String> {
        let token = match &self.token_source {
            TokenSource::BearerHeader => {
                let auth = headers.get(http::header::AUTHORIZATION)?;
                let auth_str = auth.to_str().ok()?;
                rusty_gasket::auth::backend::extract_bearer_token(auth_str).map(String::from)
            }
            TokenSource::Cookie(name) => {
                let cookies = headers.get(http::header::COOKIE)?;
                let cookies_str = cookies.to_str().ok()?;
                let prefix = format!("{name}=");
                let mut found: Option<String> = None;
                for pair in cookies_str.split(';') {
                    if let Some(value) = pair.trim().strip_prefix(&prefix) {
                        // Duplicate cookies could come from cookie tossing
                        // (e.g. a sibling subdomain); refuse rather than guess.
                        if found.is_some() {
                            tracing::warn!(
                                cookie_name = %name,
                                "Multiple cookies with the same name; refusing to extract a token"
                            );
                            return None;
                        }
                        found = Some(value.to_string());
                    }
                }
                found
            }
            TokenSource::Header(name) => {
                let value = headers.get(name)?;
                let raw = value.to_str().ok()?;
                Some(strip_bearer_prefix(raw).to_string())
            }
        };
        token.filter(|t| !t.is_empty())
    }
    /// Maps a `jsonwebtoken` decode error onto the crate's [`AuthError`].
    ///
    /// - An expired token becomes [`AuthError::TokenExpired`].
    /// - A malformed token or bad signature becomes
    ///   [`AuthError::InvalidCredentials`].
    /// - Any other failure (audience, issuer, nbf, ...) becomes
    ///   [`AuthError::TokenValidation`].
    fn decode_jwt_error(e: &jsonwebtoken::errors::Error) -> AuthError {
        match e.kind() {
            jsonwebtoken::errors::ErrorKind::ExpiredSignature => AuthError::TokenExpired,
            jsonwebtoken::errors::ErrorKind::InvalidToken
            | jsonwebtoken::errors::ErrorKind::InvalidSignature => {
                AuthError::InvalidCredentials(format!("Invalid JWT: {e}"))
            }
            _ => AuthError::TokenValidation(format!("JWT validation failed: {e}")),
        }
    }
}
impl AuthBackend for JwtBackend {
    fn name(&self) -> &'static str {
        "jwt"
    }
    /// Authenticates a request from its headers.
    ///
    /// Returns `Ok(None)` when no token is present. Returns `Ok(Some(identity))`
    /// when the token verifies and its claims map cleanly. Returns `Err` when
    /// key resolution, verification, or claims mapping fails. The URI is
    /// unused.
    async fn authenticate(
        &self,
        headers: &http::HeaderMap,
        _uri: &http::Uri,
    ) -> Result<Option<Identity>, AuthError> {
        let Some(token) = self.extract_token(headers) else {
            return Ok(None);
        };
        let decoded = match &self.key_resolver {
            KeyResolver::Static(static_key) => jsonwebtoken::decode::<serde_json::Value>(
                &token,
                &static_key.key,
                &static_key.validation,
            ),
            KeyResolver::Jwks(store) => {
                let (key, validation) = store.resolve_key(&token).await?;
                jsonwebtoken::decode::<serde_json::Value>(&token, &key, &validation)
            }
        };
        let token_data = decoded.map_err(|e| Self::decode_jwt_error(&e))?;
        self.claims_mapper.map_claims(&token_data.claims).map(Some)
    }
}
/// Raw static key material captured by the builder. It is parsed and
/// validated only in `build()`.
enum StaticKeySource {
    /// Shared HMAC secret bytes.
    Hmac(Vec<u8>),
    /// RSA public key, PEM encoded.
    RsaPem(Vec<u8>),
    /// EC public key, PEM encoded.
    EcPem(Vec<u8>),
}
/// Builder for [`JwtBackend`].
///
/// Exactly one key source must be configured. That is either a static key
/// ([`hmac_secret`](Self::hmac_secret), [`rsa_pem`](Self::rsa_pem),
/// [`ec_pem`](Self::ec_pem)) or a [`jwks_url`](Self::jwks_url).
///
/// Setter call order does not matter. An algorithm chosen with
/// [`algorithm`](Self::algorithm) is kept whether it is set before or after
/// the key. The key type only supplies a default (HS256 / RS256 / ES256) when
/// no algorithm was chosen explicitly.
#[must_use = "JwtBackendBuilder must be consumed by .build() to produce a backend"]
pub struct JwtBackendBuilder {
    static_key: Option<StaticKeySource>,
    jwks_url: Option<String>,
    /// Explicitly requested static-key algorithm. `None` means "derive from
    /// the static key type" (see `effective_algorithm`).
    algorithm: Option<Algorithm>,
    token_source: TokenSource,
    claims_mapper: Option<ClaimsMapperHandle>,
    audience: Option<String>,
    issuer: Option<String>,
    validate_exp: bool,
    validate_nbf: bool,
    leeway_secs: u64,
    jwks_cache_ttl: Duration,
    jwks_timeout: Duration,
    jwks_max_body_size: usize,
    jwks_allowed_algorithms: Option<Vec<Algorithm>>,
    jwks_allow_http: bool,
}
impl Default for JwtBackendBuilder {
    fn default() -> Self {
        Self {
            static_key: None,
            jwks_url: None,
            algorithm: None,
            token_source: TokenSource::BearerHeader,
            claims_mapper: None,
            audience: None,
            issuer: None,
            validate_exp: true,
            validate_nbf: true,
            leeway_secs: 0,
            jwks_cache_ttl: DEFAULT_JWKS_CACHE_TTL,
            jwks_timeout: DEFAULT_JWKS_TIMEOUT,
            jwks_max_body_size: DEFAULT_JWKS_MAX_BODY_SIZE,
            jwks_allowed_algorithms: None,
            jwks_allow_http: false,
        }
    }
}
impl std::fmt::Debug for JwtBackendBuilder {
    // Secrets and key material are never printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JwtBackendBuilder")
            .field("algorithm", &self.effective_algorithm())
            .field("token_source", &self.token_source)
            .field("validate_exp", &self.validate_exp)
            .field("jwks_url", &self.jwks_url)
            .field("jwks_cache_ttl", &self.jwks_cache_ttl)
            .finish_non_exhaustive()
    }
}
impl JwtBackendBuilder {
    /// Asymmetric algorithms accepted from a JWKS when no allow-list is given.
    fn default_jwks_algorithms() -> Vec<Algorithm> {
        vec![
            Algorithm::RS256,
            Algorithm::RS384,
            Algorithm::RS512,
            Algorithm::PS256,
            Algorithm::PS384,
            Algorithm::PS512,
            Algorithm::ES256,
            Algorithm::ES384,
            Algorithm::EdDSA,
        ]
    }
    /// Algorithm used for static-key validation.
    ///
    /// This is the explicit choice if one was made. Otherwise it is the
    /// conventional default for the configured key type: RS256 for RSA,
    /// ES256 for EC, and HS256 for HMAC or no static key.
    fn effective_algorithm(&self) -> Algorithm {
        self.algorithm.unwrap_or(match self.static_key {
            Some(StaticKeySource::RsaPem(_)) => Algorithm::RS256,
            Some(StaticKeySource::EcPem(_)) => Algorithm::ES256,
            Some(StaticKeySource::Hmac(_)) | None => Algorithm::HS256,
        })
    }
    /// Verifies tokens with a shared HMAC secret (HS256 unless
    /// [`algorithm`](Self::algorithm) selects HS384/HS512).
    ///
    /// The secret must be non-empty. RFC 7518 §3.2 further requires at least
    /// as many bytes as the hash output (32 for HS256), but only emptiness is
    /// enforced here.
    pub fn hmac_secret(mut self, secret: impl Into<Vec<u8>>) -> Self {
        self.static_key = Some(StaticKeySource::Hmac(secret.into()));
        self
    }
    /// Verifies tokens with a PEM-encoded RSA public key.
    ///
    /// Uses RS256 unless [`algorithm`](Self::algorithm) selects another
    /// RS*/PS* algorithm, in either call order.
    pub fn rsa_pem(mut self, pem: impl Into<Vec<u8>>) -> Self {
        self.static_key = Some(StaticKeySource::RsaPem(pem.into()));
        self
    }
    /// Verifies tokens with a PEM-encoded EC public key.
    ///
    /// Uses ES256 unless [`algorithm`](Self::algorithm) selects ES384, in
    /// either call order.
    pub fn ec_pem(mut self, pem: impl Into<Vec<u8>>) -> Self {
        self.static_key = Some(StaticKeySource::EcPem(pem.into()));
        self
    }
    /// Fetches verification keys from this JWKS URL (https required by default).
    pub fn jwks_url(mut self, url: impl Into<String>) -> Self {
        self.jwks_url = Some(url.into());
        self
    }
    /// How long a fetched JWKS is considered fresh.
    pub const fn jwks_cache_ttl(mut self, ttl: Duration) -> Self {
        self.jwks_cache_ttl = ttl;
        self
    }
    /// Total timeout for a JWKS HTTP request.
    pub const fn jwks_timeout(mut self, timeout: Duration) -> Self {
        self.jwks_timeout = timeout;
        self
    }
    /// Maximum accepted JWKS response body size in bytes.
    pub const fn jwks_max_body_size(mut self, max_bytes: usize) -> Self {
        self.jwks_max_body_size = max_bytes;
        self
    }
    /// Replaces the set of token `alg` values accepted with JWKS keys.
    /// The list must be non-empty.
    pub fn jwks_allowed_algorithms(
        mut self,
        algorithms: impl IntoIterator<Item = Algorithm>,
    ) -> Self {
        self.jwks_allowed_algorithms = Some(algorithms.into_iter().collect());
        self
    }
    /// Permits a plain `http://` JWKS URL (for local development only).
    pub const fn jwks_allow_http(mut self, allow: bool) -> Self {
        self.jwks_allow_http = allow;
        self
    }
    /// Selects the static-key algorithm.
    ///
    /// Takes precedence over the key type's default regardless of call order.
    /// It must be compatible with the key type, or `build()` fails.
    pub const fn algorithm(mut self, algorithm: Algorithm) -> Self {
        self.algorithm = Some(algorithm);
        self
    }
    /// Where to read the token from on each request.
    pub fn token_source(mut self, source: TokenSource) -> Self {
        self.token_source = source;
        self
    }
    /// Uses `mapper` to turn verified claims into an identity.
    pub fn claims_mapper(mut self, mapper: impl ClaimsMapper) -> Self {
        self.claims_mapper = Some(ClaimsMapperHandle::new(mapper));
        self
    }
    /// Like [`claims_mapper`](Self::claims_mapper) but takes a prebuilt handle.
    pub fn claims_mapper_handle(mut self, mapper: ClaimsMapperHandle) -> Self {
        self.claims_mapper = Some(mapper);
        self
    }
    /// Requires the token's `aud` to match this value.
    pub fn audience(mut self, audience: impl Into<String>) -> Self {
        self.audience = Some(audience.into());
        self
    }
    /// Requires the token's `iss` to match this value.
    pub fn issuer(mut self, issuer: impl Into<String>) -> Self {
        self.issuer = Some(issuer.into());
        self
    }
    /// Enables or disables `exp` checking. When disabled, `exp` is also not required.
    pub const fn validate_exp(mut self, validate: bool) -> Self {
        self.validate_exp = validate;
        self
    }
    /// Enables or disables `nbf` checking.
    pub const fn validate_nbf(mut self, validate: bool) -> Self {
        self.validate_nbf = validate;
        self
    }
    /// Clock-skew allowance for `exp`/`nbf`, in seconds.
    pub const fn leeway_secs(mut self, secs: u64) -> Self {
        self.leeway_secs = secs;
        self
    }
    /// Validates the configuration and builds the backend.
    ///
    /// # Errors
    /// Returns [`AuthError::Configuration`] when:
    /// - no key source, or both a static key and a JWKS URL, are configured;
    /// - the algorithm is incompatible with the static key type;
    /// - the HMAC secret is empty, or a PEM key fails to parse;
    /// - the JWKS URL is invalid, has no host, or uses a disallowed scheme;
    /// - the JWKS HTTP client cannot be built;
    /// - the JWKS allow-list is empty.
    pub fn build(self) -> Result<JwtBackend, AuthError> {
        let key_resolver = match (self.static_key.as_ref(), self.jwks_url.as_deref()) {
            (None, None) => {
                return Err(AuthError::Configuration(
                    "JWT key source is required (call hmac_secret, rsa_pem, ec_pem, or jwks_url)"
                        .to_string(),
                ));
            }
            (Some(_), Some(_)) => {
                return Err(AuthError::Configuration(
                    "JWT backend cannot use both a static key and a JWKS URL".to_string(),
                ));
            }
            (Some(static_key), None) => Self::build_static_resolver(static_key, &self)?,
            (None, Some(url)) => Self::build_jwks_resolver(url, &self)?,
        };
        Ok(JwtBackend {
            key_resolver,
            token_source: self.token_source,
            claims_mapper: self
                .claims_mapper
                .unwrap_or_else(|| ClaimsMapperHandle::new(StaticClaimsDefault::mapper())),
        })
    }
    /// Parses the static key and checks it against the effective algorithm.
    fn build_static_resolver(
        source: &StaticKeySource,
        builder: &Self,
    ) -> Result<KeyResolver, AuthError> {
        let algorithm = builder.effective_algorithm();
        let validation = builder.build_validation();
        let static_key = match source {
            StaticKeySource::Hmac(secret) => {
                let compatible = matches!(
                    algorithm,
                    Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512
                );
                if !compatible {
                    return Err(AuthError::Configuration(format!(
                        "Algorithm {:?} is incompatible with HMAC secret",
                        algorithm
                    )));
                }
                // An empty secret lets anyone mint tokens that verify.
                if secret.is_empty() {
                    return Err(AuthError::Configuration(
                        "HMAC secret must not be empty".to_string(),
                    ));
                }
                StaticKey {
                    key: DecodingKey::from_secret(secret),
                    validation,
                }
            }
            StaticKeySource::RsaPem(pem) => {
                let compatible = matches!(
                    algorithm,
                    Algorithm::RS256
                        | Algorithm::RS384
                        | Algorithm::RS512
                        | Algorithm::PS256
                        | Algorithm::PS384
                        | Algorithm::PS512
                );
                if !compatible {
                    return Err(AuthError::Configuration(format!(
                        "Algorithm {:?} is incompatible with RSA PEM key",
                        algorithm
                    )));
                }
                let key = DecodingKey::from_rsa_pem(pem)
                    .map_err(|e| AuthError::Configuration(format!("Invalid RSA PEM: {e}")))?;
                StaticKey { key, validation }
            }
            StaticKeySource::EcPem(pem) => {
                let compatible = matches!(algorithm, Algorithm::ES256 | Algorithm::ES384);
                if !compatible {
                    return Err(AuthError::Configuration(format!(
                        "Algorithm {:?} is incompatible with EC PEM key",
                        algorithm
                    )));
                }
                let key = DecodingKey::from_ec_pem(pem)
                    .map_err(|e| AuthError::Configuration(format!("Invalid EC PEM: {e}")))?;
                StaticKey { key, validation }
            }
        };
        Ok(KeyResolver::Static(Box::new(static_key)))
    }
    /// Validates the JWKS URL and allow-list, then builds the key store.
    ///
    /// The store's HTTP client does not follow redirects.
    fn build_jwks_resolver(url: &str, builder: &Self) -> Result<KeyResolver, AuthError> {
        let parsed = ::url::Url::parse(url).map_err(|e| {
            AuthError::Configuration(format!(
                "JWKS URL is not a valid absolute URL ({e}); got {url:?}"
            ))
        })?;
        if parsed.host_str().is_none_or(str::is_empty) {
            return Err(AuthError::Configuration(format!(
                "JWKS URL has no host: {url:?}"
            )));
        }
        match parsed.scheme() {
            "https" => {}
            "http" if builder.jwks_allow_http => {}
            "http" => {
                return Err(AuthError::Configuration(format!(
                    "JWKS URL must use https:// (got {url:?}); call jwks_allow_http(true) to opt out"
                )));
            }
            other => {
                return Err(AuthError::Configuration(format!(
                    "JWKS URL scheme must be http or https (got {other:?})"
                )));
            }
        }
        let http_client = reqwest::Client::builder()
            .timeout(builder.jwks_timeout)
            .redirect(reqwest::redirect::Policy::none())
            .build()
            .map_err(|e| {
                AuthError::Configuration(format!("Failed to build JWKS HTTP client: {e}"))
            })?;
        let allowed_algorithms = builder
            .jwks_allowed_algorithms
            .clone()
            .unwrap_or_else(Self::default_jwks_algorithms);
        if allowed_algorithms.is_empty() {
            return Err(AuthError::Configuration(
                "JWKS allowed_algorithms must not be empty".to_string(),
            ));
        }
        let store = JwksKeyStore {
            url: url.to_string(),
            cache_ttl: builder.jwks_cache_ttl,
            cached: ArcSwapOption::const_empty(),
            fetch_lock: Semaphore::new(1),
            http_client,
            max_body_size: builder.jwks_max_body_size,
            allowed_algorithms,
            audience: builder.audience.clone(),
            issuer: builder.issuer.clone(),
            validate_exp: builder.validate_exp,
            validate_nbf: builder.validate_nbf,
            leeway_secs: builder.leeway_secs,
            warned_drops: StdMutex::new(HashSet::new()),
        };
        Ok(KeyResolver::Jwks(Arc::new(store)))
    }
    /// Validation rules for static-key mode, pinned to the effective algorithm.
    fn build_validation(&self) -> Validation {
        let mut validation = Validation::new(self.effective_algorithm());
        validation.validate_exp = self.validate_exp;
        validation.validate_nbf = self.validate_nbf;
        validation.leeway = self.leeway_secs;
        // `exp` is a required claim by default; drop it when expiry is not validated.
        if !self.validate_exp {
            validation.required_spec_claims.remove("exp");
        }
        if let Some(ref aud) = self.audience {
            validation.set_audience(&[aud]);
        }
        if let Some(ref iss) = self.issuer {
            validation.set_issuer(&[iss]);
        }
        validation
    }
}
/// Supplies the claims mapper used when none is configured on the builder.
struct StaticClaimsDefault;
impl StaticClaimsDefault {
    /// The default mapper: [`StandardClaimsMapper`].
    fn mapper() -> StandardClaimsMapper {
        StandardClaimsMapper
    }
}
// Unit tests for the JWT backend.
// They cover log sanitizing, header parsing, claims mapping, HMAC end-to-end
// authentication, and builder validation.
#[cfg(test)]
// `panic!` is used in a `let ... else` to fail a test on an unexpected resolver variant.
#[allow(clippy::panic)]
mod tests {
    use super::*;
    // Control bytes that could forge log lines or terminal escapes are removed,
    // and output is capped at 128 characters.
    #[test]
    fn sanitize_for_log_strips_control_bytes_and_truncates() {
        let attacker = "bad\nINFO fake_user logged_in\r\nx\u{1b}[1mboldhack\u{1b}[0m";
        let sanitized = sanitize_for_log(attacker);
        assert!(!sanitized.contains('\n'));
        assert!(!sanitized.contains('\r'));
        assert!(!sanitized.contains('\u{1b}'));
        assert!(sanitized.starts_with("bad"));
        let big: String = std::iter::repeat_n('a', 1024).collect();
        assert_eq!(sanitize_for_log(&big).len(), 128);
    }
    // The token-controlled `kid` must not inject newlines into rendered errors.
    #[test]
    fn kid_not_found_error_does_not_contain_control_bytes() {
        let err = kid_not_found("evil\nfake\r\nlog line");
        let s = err.to_string();
        assert!(!s.contains('\n'), "rendered error contains newline: {s:?}");
        assert!(!s.contains('\r'), "rendered error contains CR: {s:?}");
    }
    // Case-insensitive scheme, leading whitespace, and no false match on a
    // word that merely starts with "Bearer".
    #[test]
    fn strip_bearer_prefix_handles_common_cases() {
        assert_eq!(strip_bearer_prefix("eyJabc.def.ghi"), "eyJabc.def.ghi");
        assert_eq!(
            strip_bearer_prefix("Bearer eyJabc.def.ghi"),
            "eyJabc.def.ghi"
        );
        assert_eq!(
            strip_bearer_prefix("bearer eyJabc.def.ghi"),
            "eyJabc.def.ghi"
        );
        assert_eq!(
            strip_bearer_prefix("BEARER eyJabc.def.ghi"),
            "eyJabc.def.ghi"
        );
        assert_eq!(
            strip_bearer_prefix(" Bearer eyJabc.def.ghi"),
            "eyJabc.def.ghi"
        );
        assert_eq!(
            strip_bearer_prefix("Bearer eyJabc.def.ghi"),
            "eyJabc.def.ghi"
        );
        assert_eq!(strip_bearer_prefix("BearerLikeWord"), "BearerLikeWord");
    }
    // Default OAuth mapper: standard fields plus an `OAuthTokenClaims` attribute.
    #[test]
    fn oauth_claims_mapper_preserves_client_id_by_default() {
        let claims = serde_json::json!({
            "sub": "oauth_client:svc-a",
            "client_id": "svc-a",
            "scope": "read write",
            "name": "Service A",
        });
        let identity = OAuthClaimsMapper::new()
            .map_claims(&claims)
            .expect("map oauth claims");
        assert_eq!(identity.subject(), "oauth_client:svc-a");
        assert_eq!(identity.auth_method(), "oauth-jwt");
        assert_eq!(identity.display_name(), Some("Service A"));
        assert!(identity.has_scope("read"));
        assert!(identity.has_scope("write"));
        assert_eq!(
            identity
                .attributes()
                .get::<OAuthTokenClaims>()
                .and_then(OAuthTokenClaims::client_id),
            Some("svc-a")
        );
    }
    // With `ignore_client_id`, no OAuth attribute is attached at all.
    #[test]
    fn oauth_claims_mapper_can_ignore_client_id() {
        let claims = serde_json::json!({
            "sub": "oauth_client:svc-a",
            "client_id": "svc-a",
        });
        let identity = OAuthClaimsMapper::new()
            .ignore_client_id()
            .map_claims(&claims)
            .expect("map oauth claims");
        assert!(identity.attributes().get::<OAuthTokenClaims>().is_none());
    }
    // HS256 backend with expiry checks off, so test tokens need no `exp`.
    fn make_test_backend() -> JwtBackend {
        JwtBackend::builder()
            .hmac_secret(b"test-secret-key-that-is-long-enough")
            .validate_exp(false)
            .build()
            .expect("valid builder")
    }
    // Signs `claims` with the same secret as `make_test_backend`.
    fn make_test_token(claims: &serde_json::Value) -> String {
        let header = jsonwebtoken::Header::new(Algorithm::HS256);
        let key = jsonwebtoken::EncodingKey::from_secret(b"test-secret-key-that-is-long-enough");
        jsonwebtoken::encode(&header, claims, &key).expect("encode token")
    }
    // Happy path through the Authorization header.
    #[tokio::test]
    async fn valid_bearer_token() {
        let backend = make_test_backend();
        let token = make_test_token(&serde_json::json!({
            "sub": "user-123",
            "scope": "read write",
            "name": "Test User",
        }));
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::AUTHORIZATION,
            format!("Bearer {token}").parse().expect("valid header"),
        );
        let result = backend
            .authenticate(&headers, &"/test".parse().expect("valid uri"))
            .await
            .expect("should succeed");
        let identity = result.expect("should have identity");
        assert_eq!(identity.subject(), "user-123");
        assert_eq!(identity.auth_method(), "jwt");
        assert_eq!(identity.display_name(), Some("Test User"));
        assert!(identity.has_scope("read"));
        assert!(identity.has_scope("write"));
        assert!(!identity.has_scope("admin"));
    }
    // A missing credential is "not handled" (Ok(None)), not an error.
    #[tokio::test]
    async fn no_auth_header_returns_none() {
        let backend = make_test_backend();
        let headers = http::HeaderMap::new();
        let result = backend
            .authenticate(&headers, &"/test".parse().expect("valid uri"))
            .await
            .expect("should succeed");
        assert!(result.is_none());
    }
    #[tokio::test]
    async fn invalid_token_returns_error() {
        let backend = make_test_backend();
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::AUTHORIZATION,
            "Bearer not-a-valid-jwt".parse().expect("valid header"),
        );
        let result = backend
            .authenticate(&headers, &"/test".parse().expect("valid uri"))
            .await;
        assert!(result.is_err());
    }
    // A token signed with a different secret must fail signature verification.
    #[tokio::test]
    async fn wrong_secret_returns_error() {
        let backend = JwtBackend::builder()
            .hmac_secret(b"different-secret-key-that-is-long-enough")
            .validate_exp(false)
            .build()
            .expect("valid builder");
        let token = make_test_token(&serde_json::json!({"sub": "user-123"}));
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::AUTHORIZATION,
            format!("Bearer {token}").parse().expect("valid header"),
        );
        let result = backend
            .authenticate(&headers, &"/test".parse().expect("valid uri"))
            .await;
        assert!(result.is_err());
    }
    // A valid signature without `sub` is rejected by the claims mapper.
    #[tokio::test]
    async fn missing_sub_claim_returns_error() {
        let backend = make_test_backend();
        let token = make_test_token(&serde_json::json!({"scope": "read"}));
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::AUTHORIZATION,
            format!("Bearer {token}").parse().expect("valid header"),
        );
        let result = backend
            .authenticate(&headers, &"/test".parse().expect("valid uri"))
            .await;
        assert!(result.is_err());
    }
    // Builder configuration validation.
    #[test]
    fn builder_requires_key_source() {
        let result = JwtBackend::builder().build();
        assert!(result.is_err());
    }
    #[test]
    fn builder_rejects_static_and_jwks() {
        let result = JwtBackend::builder()
            .hmac_secret(b"a-very-long-test-secret-key-please")
            .jwks_url("https://example.com/jwks.json")
            .build();
        assert!(result.is_err());
    }
    #[test]
    fn jwks_url_requires_https_by_default() {
        let result = JwtBackend::builder()
            .jwks_url("http://example.com/jwks.json")
            .build();
        assert!(result.is_err());
    }
    // Opting into http must not open the door to other schemes.
    #[test]
    fn jwks_url_rejects_unknown_scheme() {
        let result = JwtBackend::builder()
            .jwks_url("file:///etc/passwd")
            .jwks_allow_http(true)
            .build();
        assert!(result.is_err());
    }
    #[test]
    fn jwks_url_allows_http_when_opted_in() {
        let result = JwtBackend::builder()
            .jwks_url("http://localhost:8080/jwks.json")
            .jwks_allow_http(true)
            .build();
        assert!(result.is_ok());
    }
    #[test]
    fn jwks_empty_allowed_algorithms_rejected() {
        let result = JwtBackend::builder()
            .jwks_url("https://example.com/jwks.json")
            .jwks_allowed_algorithms(vec![])
            .build();
        assert!(result.is_err());
    }
    // Token read from a named cookie among several cookies.
    #[tokio::test]
    async fn cookie_token_source() {
        let backend = JwtBackend::builder()
            .hmac_secret(b"test-secret-key-that-is-long-enough")
            .token_source(TokenSource::Cookie("auth_token".to_string()))
            .validate_exp(false)
            .build()
            .expect("valid builder");
        let token = make_test_token(&serde_json::json!({"sub": "cookie-user"}));
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::COOKIE,
            format!("auth_token={token}; other=value")
                .parse()
                .expect("valid header"),
        );
        let result = backend
            .authenticate(&headers, &"/test".parse().expect("valid uri"))
            .await
            .expect("should succeed");
        let identity = result.expect("should have identity");
        assert_eq!(identity.subject(), "cookie-user");
    }
    // Building a JWKS backend does no network I/O; keys are fetched lazily.
    #[test]
    fn jwks_builder_creates_backend() {
        let backend = JwtBackend::builder()
            .jwks_url("https://auth.example.com/.well-known/jwks.json")
            .audience("my-api")
            .issuer("https://auth.example.com")
            .build();
        assert!(backend.is_ok());
        let backend = backend.expect("should build");
        assert!(matches!(backend.key_resolver, KeyResolver::Jwks(_)));
    }
    // Setter order independence: TTL set before the URL still applies.
    #[test]
    fn jwks_cache_ttl_takes_effect_regardless_of_call_order() {
        let backend = JwtBackend::builder()
            .jwks_cache_ttl(Duration::from_secs(42))
            .jwks_url("https://auth.example.com/jwks.json")
            .build()
            .expect("valid builder");
        let KeyResolver::Jwks(store) = &backend.key_resolver else {
            panic!("expected JWKS resolver");
        };
        assert_eq!(store.cache_ttl, Duration::from_secs(42));
    }
}