use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::{Context, Result};
use argon2::{Argon2, PasswordHash, PasswordVerifier};
use axum::http::{header, HeaderMap, HeaderName};
use base64::{engine::general_purpose::STANDARD as B64, Engine};
use jsonwebtoken::jwk::JwkSet;
use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
use serde_json::Value;
use tokio::sync::RwLock;
use tracing::warn;
use crate::config::{AuthCfg, JwtCfg};
/// Outcome of `AuthEngine::authorize` for a single request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Decision {
    /// The request may proceed. Carries the authenticated principal when one is known:
    /// the Basic username, `apikey:<id>` for API keys, or the JWT `sub` claim.
    Allow(Option<String>),
    /// The request is rejected; carries the challenge to send back to the client.
    Deny(Challenge),
}

/// Challenge attached to a denied request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Challenge {
    /// Pre-formatted value of the form `Basic realm="<realm>"` (Basic mode).
    /// NOTE(review): presumably emitted verbatim as `WWW-Authenticate` by the caller — confirm.
    Basic(String),
    /// Bearer challenge (JWT mode).
    Bearer,
    /// No challenge (API-key mode).
    None,
}
/// Authentication strategy chosen from `auth.mode` by `AuthEngine::build`.
///
/// Deliberately not `Debug`: the `ApiKey` variant holds the raw configured secrets.
pub enum AuthEngine {
/// `none`: every request is allowed with no principal.
Open,
/// `basic`: HTTP Basic credentials are checked against `AuthCfg::users` by `check_basic_auth`.
Basic,
/// `apikey`: the key presented in `header` (or as `Authorization: Bearer`) must equal one of `keys`.
ApiKey {
keys: Vec<String>,
header: HeaderName,
},
/// `jwt`: the bearer token must pass `JwtValidator::verify`.
Jwt(Box<JwtValidator>),
}
impl AuthEngine {
    /// Builds the engine named by `auth.mode` (`none`, `basic`, `apikey` or `jwt`).
    ///
    /// # Errors
    /// Returns an error for an unknown mode, for an `auth.api_key_header` that is not a
    /// valid HTTP header name, or for any failure reported by `JwtValidator::build`.
    pub fn build(cfg: &AuthCfg) -> Result<AuthEngine> {
        let engine = match cfg.mode.as_str() {
            "none" => AuthEngine::Open,
            "basic" => AuthEngine::Basic,
            "apikey" => {
                let header = HeaderName::from_bytes(cfg.api_key_header.as_bytes())
                    .context("invalid auth.api_key_header")?;
                let keys = cfg.api_keys.clone();
                AuthEngine::ApiKey { keys, header }
            }
            "jwt" => {
                let validator = JwtValidator::build(&cfg.jwt)?;
                AuthEngine::Jwt(Box::new(validator))
            }
            other => anyhow::bail!("unknown auth.mode: {other:?} (expected none|basic|apikey|jwt)"),
        };
        Ok(engine)
    }

    /// Decides whether a request carrying `headers` may proceed.
    ///
    /// `cfg` is read only in Basic mode (user table and realm). JWT verification errors are
    /// not surfaced; any failure becomes a Bearer challenge.
    pub async fn authorize(&self, cfg: &AuthCfg, headers: &HeaderMap) -> Decision {
        match self {
            AuthEngine::Open => Decision::Allow(None),
            AuthEngine::Basic => {
                if !check_basic_auth(cfg, headers) {
                    let challenge = format!("Basic realm=\"{}\"", cfg.realm);
                    return Decision::Deny(Challenge::Basic(challenge));
                }
                Decision::Allow(basic_username(headers))
            }
            AuthEngine::ApiKey { keys, header } => verify_api_key(keys, header, headers)
                .map_or_else(
                    || Decision::Deny(Challenge::None),
                    |who| Decision::Allow(Some(who)),
                ),
            AuthEngine::Jwt(validator) => {
                let Some(token) = bearer_token(headers) else {
                    return Decision::Deny(Challenge::Bearer);
                };
                match validator.verify(token).await {
                    Ok(principal) => Decision::Allow(principal),
                    Err(_) => Decision::Deny(Challenge::Bearer),
                }
            }
        }
    }
}
/// Returns `true` when the request's Basic credentials match an entry in `cfg.users`.
///
/// A stored value starting with `$argon2` is treated as an Argon2 PHC string and checked
/// with `Argon2::verify_password`; an unparsable PHC string never matches. Any other stored
/// value is compared against the presented password as plaintext, via `constant_time_eq`.
pub fn check_basic_auth(cfg: &AuthCfg, headers: &HeaderMap) -> bool {
    let (user, pass) = match basic_credentials(headers) {
        Some(creds) => creds,
        None => return false,
    };
    let stored = match cfg.users.get(&user) {
        Some(secret) => secret,
        None => return false,
    };
    if !stored.starts_with("$argon2") {
        return constant_time_eq(stored.as_bytes(), pass.as_bytes());
    }
    PasswordHash::new(stored)
        .map(|parsed| {
            Argon2::default()
                .verify_password(pass.as_bytes(), &parsed)
                .is_ok()
        })
        .unwrap_or(false)
}
/// Extracts `(user, password)` from an `Authorization: Basic <base64>` header.
///
/// The scheme name is matched case-insensitively, as RFC 7235 §2.1 requires, so `basic …`
/// is accepted as well as `Basic …`.
///
/// Returns `None` when any of these hold:
/// - the header is missing or not visible ASCII;
/// - it uses another scheme;
/// - the payload is not valid base64 or does not decode to UTF-8;
/// - the decoded text has no `:` separator.
///
/// The password is everything after the first `:`, so it may itself contain colons.
fn basic_credentials(headers: &HeaderMap) -> Option<(String, String)> {
    let value = headers.get(header::AUTHORIZATION)?.to_str().ok()?;
    let (scheme, encoded) = value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Basic") {
        return None;
    }
    let decoded = B64.decode(encoded.trim()).ok()?;
    let creds = String::from_utf8(decoded).ok()?;
    let (user, pass) = creds.split_once(':')?;
    Some((user.to_owned(), pass.to_owned()))
}
/// Username part of the request's Basic credentials, if they can be decoded.
fn basic_username(headers: &HeaderMap) -> Option<String> {
    let (user, _password) = basic_credentials(headers)?;
    Some(user)
}
/// Returns the trimmed token from an `Authorization: Bearer <token>` header.
///
/// The scheme name is matched case-insensitively, as RFC 7235 §2.1 requires.
///
/// Returns `None` when any of these hold:
/// - the header is missing or not visible ASCII;
/// - it uses another scheme;
/// - the token is empty after trimming.
///
/// Rejecting the empty token means an `Authorization: Bearer ` header with nothing after it
/// is never mistaken for a credential.
fn bearer_token(headers: &HeaderMap) -> Option<&str> {
    let value = headers.get(header::AUTHORIZATION)?.to_str().ok()?;
    let (scheme, token) = value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let token = token.trim();
    (!token.is_empty()).then_some(token)
}
/// Authenticates a request by API key and returns its principal, `apikey:<short_id>`.
///
/// The key is read from `header` (trimmed). If that header is absent or not visible ASCII,
/// the token from `Authorization: Bearer <key>` is used instead.
///
/// An empty presented key is rejected outright. Otherwise a blank entry in `keys` (e.g. from
/// an unset environment substitution) would admit any request sending an empty header.
///
/// The presented key is compared against every configured key without stopping early, so
/// the position of a match is not revealed through an early exit. The raw key never appears
/// in the returned principal.
pub fn verify_api_key(keys: &[String], header: &HeaderName, headers: &HeaderMap) -> Option<String> {
    let presented = headers
        .get(header)
        .and_then(|v| v.to_str().ok())
        .map(str::trim)
        .or_else(|| bearer_token(headers))
        .filter(|k| !k.is_empty())?;
    let mut matched: Option<&String> = None;
    for key in keys {
        // No `break`: every configured key is compared regardless of earlier matches.
        if constant_time_eq(key.as_bytes(), presented.as_bytes()) {
            matched = Some(key);
        }
    }
    matched.map(|k| format!("apikey:{}", short_id(k)))
}
/// Derives a 16-hex-digit identifier from `secret`, for use in principals and logs in place
/// of the key itself.
///
/// Uses `DefaultHasher::new()`, which is deterministic within one build. The std docs state
/// that its algorithm is unspecified and may change between Rust releases, so identifiers
/// are not guaranteed to stay the same across toolchain upgrades. It is not a cryptographic
/// hash: a low-entropy secret could be recovered from its identifier by brute force.
fn short_id(secret: &str) -> String {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();
    secret.hash(&mut hasher);
    let digest: u64 = hasher.finish();
    format!("{digest:016x}")
}
/// Hashes `password` with Argon2 (default parameters) and a fresh random salt from `OsRng`.
///
/// Returns the PHC string (starting with `$argon2`), the form `check_basic_auth` recognises
/// in `AuthCfg::users`.
///
/// # Errors
/// Propagates any Argon2 hashing failure as `hashing password: <cause>`.
pub fn hash_password(password: &str) -> Result<String> {
    use argon2::password_hash::rand_core::OsRng;
    use argon2::password_hash::{PasswordHasher, SaltString};

    let salt = SaltString::generate(&mut OsRng);
    let hasher = Argon2::default();
    match hasher.hash_password(password.as_bytes(), &salt) {
        Ok(phc) => Ok(phc.to_string()),
        Err(e) => Err(anyhow::anyhow!("hashing password: {e}")),
    }
}
/// Byte-wise equality with no data-dependent early exit.
///
/// The accumulator starts from `a.len() ^ b.len()`, so inputs of different lengths can never
/// compare equal. The loop then walks the full length of the longer input, padding the
/// shorter one with zeros.
///
/// The total length is still observable. The source has no early exit, but nothing here
/// formally stops the optimizer from introducing one.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    let span = usize::max(a.len(), b.len());
    let acc = (0..span).fold(a.len() ^ b.len(), |acc, idx| {
        let lhs = a.get(idx).copied().unwrap_or(0);
        let rhs = b.get(idx).copied().unwrap_or(0);
        acc | usize::from(lhs ^ rhs)
    });
    acc == 0
}
/// Identity returned by `JwtValidator::verify`: the token's `sub` claim when it is a JSON
/// string, otherwise `None`.
type Principal = Option<String>;
/// Verifies bearer JWTs against one pinned algorithm and a static or JWKS-backed key.
pub struct JwtValidator {
/// Algorithm every token header must declare; checked in `verify` before any key lookup.
alg: Algorithm,
/// Claim checks (leeway, `nbf`, optional issuer/audience) assembled in `build`.
validation: Validation,
/// Where verification keys come from.
keys: KeySource,
}
/// Origin of JWT verification keys; chosen in `JwtValidator::build`.
enum KeySource {
/// Single key built at startup from `auth.jwt.secret` or `auth.jwt.public_key_pem`.
Static(Arc<DecodingKey>),
/// Keys fetched from `auth.jwt.jwks_url` and cached; picked per token by its `kid` header.
Jwks(JwksCache),
}
impl JwtValidator {
pub fn build(cfg: &JwtCfg) -> Result<JwtValidator> {
let alg = parse_algorithm(&cfg.algorithm)?;
let mut validation = Validation::new(alg);
validation.leeway = cfg.leeway_secs;
validation.validate_nbf = true;
if !cfg.issuer.is_empty() {
validation.set_issuer(std::slice::from_ref(&cfg.issuer));
}
if cfg.audience.is_empty() {
validation.validate_aud = false;
} else {
validation.set_audience(std::slice::from_ref(&cfg.audience));
}
let keys = if !cfg.jwks_url.is_empty() {
KeySource::Jwks(JwksCache::new(
cfg.jwks_url.clone(),
Duration::from_secs(cfg.jwks_cache_secs),
)?)
} else {
KeySource::Static(Arc::new(static_key(cfg, alg)?))
};
Ok(JwtValidator {
alg,
validation,
keys,
})
}
pub async fn verify(&self, token: &str) -> Result<Principal> {
let header = decode_header(token).context("malformed JWT header")?;
anyhow::ensure!(
header.alg == self.alg,
"token alg {:?} != configured {:?}",
header.alg,
self.alg
);
let key = match &self.keys {
KeySource::Static(k) => k.clone(),
KeySource::Jwks(cache) => cache.key_for(header.kid.as_deref()).await?,
};
let data = decode::<Value>(token, &key, &self.validation).context("JWT rejected")?;
let principal = data
.claims
.get("sub")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
Ok(principal)
}
}
fn static_key(cfg: &JwtCfg, alg: Algorithm) -> Result<DecodingKey> {
use Algorithm::*;
match alg {
HS256 | HS384 | HS512 => {
anyhow::ensure!(
!cfg.secret.is_empty(),
"auth.jwt.secret (or $EDGEGUARD_JWT_SECRET) is required for HS* algorithms"
);
Ok(DecodingKey::from_secret(cfg.secret.as_bytes()))
}
RS256 | RS384 | RS512 | PS256 | PS384 | PS512 => {
anyhow::ensure!(
!cfg.public_key_pem.is_empty(),
"auth.jwt.public_key_pem (or jwks_url) is required for RS*/PS* algorithms"
);
DecodingKey::from_rsa_pem(cfg.public_key_pem.as_bytes())
.context("parsing auth.jwt.public_key_pem as RSA")
}
ES256 | ES384 => {
anyhow::ensure!(
!cfg.public_key_pem.is_empty(),
"auth.jwt.public_key_pem (or jwks_url) is required for ES* algorithms"
);
DecodingKey::from_ec_pem(cfg.public_key_pem.as_bytes())
.context("parsing auth.jwt.public_key_pem as EC")
}
EdDSA => {
anyhow::ensure!(
!cfg.public_key_pem.is_empty(),
"auth.jwt.public_key_pem (or jwks_url) is required for EdDSA"
);
DecodingKey::from_ed_pem(cfg.public_key_pem.as_bytes())
.context("parsing auth.jwt.public_key_pem as Ed25519")
}
}
}
/// Parses `auth.jwt.algorithm` case-insensitively (e.g. `hs256`, `EdDSA`).
///
/// # Errors
/// Returns an error for any other name. The message echoes the upper-cased input.
fn parse_algorithm(s: &str) -> Result<Algorithm> {
    // Accepted spellings, compared after ASCII upper-casing.
    const NAMES: [(&str, Algorithm); 12] = [
        ("HS256", Algorithm::HS256),
        ("HS384", Algorithm::HS384),
        ("HS512", Algorithm::HS512),
        ("RS256", Algorithm::RS256),
        ("RS384", Algorithm::RS384),
        ("RS512", Algorithm::RS512),
        ("PS256", Algorithm::PS256),
        ("PS384", Algorithm::PS384),
        ("PS512", Algorithm::PS512),
        ("ES256", Algorithm::ES256),
        ("ES384", Algorithm::ES384),
        ("EDDSA", Algorithm::EdDSA),
    ];
    let wanted = s.to_ascii_uppercase();
    NAMES
        .iter()
        .find(|(name, _)| *name == wanted)
        .map(|&(_, alg)| alg)
        .ok_or_else(|| anyhow::anyhow!("unsupported auth.jwt.algorithm: {wanted}"))
}
/// Lazily populated cache of the verification keys published at a JWKS URL.
///
/// A fetched key set is served from memory while it is younger than `ttl`.
///
/// A refresh is needed when the set is stale, lacks the requested `kid`, or was never
/// fetched. Refreshes are rate-limited to one attempt per
/// `min(ttl, MIN_REFRESH_INTERVAL)`. Between attempts, requests are answered from whatever
/// is cached (possibly stale keys); they do not each issue their own HTTP request while
/// holding the write lock.
struct JwksCache {
    /// Endpoint the JWK set is fetched from.
    url: String,
    /// How long a fetched key set counts as fresh.
    ttl: Duration,
    /// HTTP client used for fetches (5 s timeout, see `JwksCache::new`).
    http: reqwest::Client,
    /// Cached keys plus refresh bookkeeping.
    ///
    /// The write lock is held across a fetch, so concurrent misses wait for one refresh and
    /// then reuse its result.
    inner: RwLock<CacheState>,
}

/// Mutable state guarded by `JwksCache::inner`.
#[derive(Default)]
struct CacheState {
    /// Most recent successfully fetched key set, if any.
    keys: Option<CachedKeys>,
    /// When the last fetch was started, whether or not it succeeded; drives the cooldown.
    last_attempt: Option<Instant>,
}

/// One successfully fetched key set.
struct CachedKeys {
    /// When this set was fetched; compared against `JwksCache::ttl`.
    fetched_at: Instant,
    /// Keys indexed by `kid` (`""` for keys published without one, see `parse_jwks`).
    by_kid: HashMap<String, Arc<DecodingKey>>,
}

impl JwksCache {
    /// Minimum spacing between fetch attempts (capped at `ttl`).
    ///
    /// Without it, two kinds of request would each trigger a new fetch under the write lock:
    /// - every token carrying an unknown, attacker-chosen `kid`;
    /// - every request made while the endpoint is failing.
    ///
    /// Each such fetch would serialize all JWT traffic behind the endpoint's latency, or
    /// behind the 5 s client timeout.
    const MIN_REFRESH_INTERVAL: Duration = Duration::from_secs(10);

    /// Creates an empty cache for `url`; nothing is fetched until the first lookup.
    ///
    /// # Errors
    /// Fails if the HTTP client cannot be constructed.
    fn new(url: String, ttl: Duration) -> Result<JwksCache> {
        let http = reqwest::Client::builder()
            .timeout(Duration::from_secs(5))
            .build()
            .context("building JWKS HTTP client")?;
        Ok(JwksCache {
            url,
            ttl,
            http,
            inner: RwLock::new(CacheState::default()),
        })
    }

    /// Returns the key for `kid`; `select_key` defines the rules for tokens without a `kid`.
    ///
    /// The key set is refreshed first when it is stale or lacks that key, provided the
    /// refresh cooldown has elapsed. If a refresh fails while keys are cached, the cached
    /// (possibly stale) keys are used and a warning is logged.
    ///
    /// # Errors
    /// Returns an error in any of these cases:
    /// - a fetch failed and nothing is cached;
    /// - no fetch has succeeded yet and the cooldown is still running;
    /// - no cached key matches `kid`.
    async fn key_for(&self, kid: Option<&str>) -> Result<Arc<DecodingKey>> {
        if let Some(key) = self.lookup_fresh(kid).await {
            return Ok(key);
        }
        let mut state = self.inner.write().await;
        // Re-check under the write lock: a task that held it before us may have refreshed.
        let needs_fetch = match state.keys.as_ref() {
            Some(c) => c.fetched_at.elapsed() > self.ttl || select_key(&c.by_kid, kid).is_none(),
            None => true,
        };
        let cooldown = self.ttl.min(Self::MIN_REFRESH_INTERVAL);
        let cooled_down = match state.last_attempt {
            Some(at) => at.elapsed() >= cooldown,
            None => true,
        };
        if needs_fetch && cooled_down {
            // Record the attempt before fetching so a failure also starts the cooldown.
            state.last_attempt = Some(Instant::now());
            match self.fetch().await {
                Ok(by_kid) => {
                    state.keys = Some(CachedKeys {
                        fetched_at: Instant::now(),
                        by_kid,
                    });
                }
                Err(e) if state.keys.is_some() => {
                    warn!(error = %format!("{e:#}"), "JWKS refresh failed; using cached keys");
                }
                Err(e) => return Err(e.context("JWKS refresh failed and no cached keys")),
            }
        }
        let Some(cached) = state.keys.as_ref() else {
            anyhow::bail!("JWKS unavailable: no successful fetch yet (retry is rate-limited)");
        };
        if let Some(key) = select_key(&cached.by_kid, kid) {
            return Ok(key);
        }
        match kid {
            Some(k) => anyhow::bail!("no JWKS key for kid {k:?}"),
            None => anyhow::bail!("JWKS contains no usable key"),
        }
    }

    /// Read-locked fast path: the key for `kid` if the cached set is still within `ttl`.
    async fn lookup_fresh(&self, kid: Option<&str>) -> Option<Arc<DecodingKey>> {
        let state = self.inner.read().await;
        let cached = state.keys.as_ref()?;
        if cached.fetched_at.elapsed() > self.ttl {
            return None;
        }
        select_key(&cached.by_kid, kid)
    }

    /// GETs the JWKS document and parses it with `parse_jwks`.
    ///
    /// # Errors
    /// Returns an error for transport failures, non-success HTTP status, an unreadable body,
    /// or a document without usable keys.
    async fn fetch(&self) -> Result<HashMap<String, Arc<DecodingKey>>> {
        let body = self
            .http
            .get(&self.url)
            .send()
            .await
            .with_context(|| format!("fetching JWKS from {}", self.url))?
            .error_for_status()
            .context("JWKS endpoint returned an error status")?
            .text()
            .await
            .context("reading JWKS body")?;
        parse_jwks(&body)
    }
}
/// Picks the verification key for a token's `kid` header from a parsed JWKS.
///
/// - With a `kid`, only an exact match is returned.
/// - Without a `kid`, the sole key is returned when the set has exactly one key.
/// - Otherwise, without a `kid`, the key published without a `kid` (stored under `""`) is
///   returned, if there is one.
fn select_key(
    by_kid: &HashMap<String, Arc<DecodingKey>>,
    kid: Option<&str>,
) -> Option<Arc<DecodingKey>> {
    let found = if let Some(wanted) = kid {
        by_kid.get(wanted)
    } else if by_kid.len() == 1 {
        by_kid.values().next()
    } else {
        by_kid.get("")
    };
    found.cloned()
}
fn parse_jwks(json: &str) -> Result<HashMap<String, Arc<DecodingKey>>> {
let set: JwkSet = serde_json::from_str(json).context("parsing JWKS JSON")?;
let mut by_kid = HashMap::new();
for jwk in &set.keys {
match DecodingKey::from_jwk(jwk) {
Ok(key) => {
let kid = jwk.common.key_id.clone().unwrap_or_default();
by_kid.insert(kid, Arc::new(key));
}
Err(e) => warn!(error = %e, "skipping unusable JWKS key"),
}
}
anyhow::ensure!(!by_kid.is_empty(), "JWKS contained no usable keys");
Ok(by_kid)
}
#[cfg(test)]
mod tests {
// Unit tests for Basic, API-key and JWT (HS256 + JWKS parsing) authentication.
use super::*;
use crate::config::AuthCfg;
use jsonwebtoken::{encode, EncodingKey, Header};
use serde_json::json;
use std::collections::BTreeMap;
// Builds a HeaderMap holding one header; panics if `value` is not a valid header value.
fn headers_with(name: &'static str, value: &str) -> HeaderMap {
let mut h = HeaderMap::new();
h.insert(name, value.parse().unwrap());
h
}
// Encodes `user:pass` as an `Authorization: Basic …` header value.
fn basic_value(user: &str, pass: &str) -> String {
format!("Basic {}", B64.encode(format!("{user}:{pass}")))
}
// Default AuthCfg whose user table contains exactly one (user, stored secret) entry.
fn cfg_with_user(user: &str, secret: &str) -> AuthCfg {
AuthCfg {
users: BTreeMap::from([(user.to_string(), secret.to_string())]),
..Default::default()
}
}
// Plaintext stored password: right password passes; wrong password or unknown user fails.
#[test]
fn basic_auth_plaintext_accepts_correct_rejects_bad() {
let cfg = cfg_with_user("admin", "s3cret");
assert!(check_basic_auth(
&cfg,
&headers_with("authorization", &basic_value("admin", "s3cret"))
));
assert!(!check_basic_auth(
&cfg,
&headers_with("authorization", &basic_value("admin", "wrong"))
));
assert!(!check_basic_auth(
&cfg,
&headers_with("authorization", &basic_value("ghost", "s3cret"))
));
}
// Missing header, wrong scheme and undecodable base64 are all rejected.
#[test]
fn basic_auth_rejects_missing_and_malformed_headers() {
let cfg = cfg_with_user("admin", "s3cret");
assert!(!check_basic_auth(&cfg, &HeaderMap::new()));
assert!(!check_basic_auth(
&cfg,
&headers_with("authorization", "Bearer token")
));
assert!(!check_basic_auth(
&cfg,
&headers_with("authorization", "Basic !!!not-base64!!!")
));
}
// Round-trip through hash_password: the stored PHC string verifies only the original password.
#[test]
fn basic_auth_argon2_path() {
let phc = hash_password("hunter2").unwrap();
assert!(phc.starts_with("$argon2"), "{phc}");
let cfg = cfg_with_user("admin", &phc);
assert!(check_basic_auth(
&cfg,
&headers_with("authorization", &basic_value("admin", "hunter2"))
));
assert!(!check_basic_auth(
&cfg,
&headers_with("authorization", &basic_value("admin", "nope"))
));
}
// Unequal lengths never compare equal, including the empty-vs-nonempty case.
#[test]
fn constant_time_eq_handles_differing_lengths() {
assert!(constant_time_eq(b"abc", b"abc"));
assert!(!constant_time_eq(b"abc", b"abd"));
assert!(!constant_time_eq(b"abc", b"abcd"));
assert!(!constant_time_eq(b"", b"x"));
assert!(constant_time_eq(b"", b""));
}
// Keys are accepted from the configured header or from `Authorization: Bearer`.
#[test]
fn api_key_accepts_via_bearer_and_header_rejects_unknown() {
let keys = vec!["sk_live_abc".to_string(), "sk_live_def".to_string()];
let header = HeaderName::from_static("x-api-key");
assert!(
verify_api_key(&keys, &header, &headers_with("x-api-key", "sk_live_abc")).is_some()
);
assert!(verify_api_key(
&keys,
&header,
&headers_with("authorization", "Bearer sk_live_def")
)
.is_some());
assert!(verify_api_key(&keys, &header, &headers_with("x-api-key", "nope")).is_none());
assert!(verify_api_key(&keys, &header, &HeaderMap::new()).is_none());
}
// The principal is deterministic and never embeds the raw key.
#[test]
fn api_key_principal_is_stable_and_not_the_raw_key() {
let keys = vec!["super-secret-key".to_string()];
let header = HeaderName::from_static("x-api-key");
let p1 = verify_api_key(
&keys,
&header,
&headers_with("x-api-key", "super-secret-key"),
);
let p2 = verify_api_key(
&keys,
&header,
&headers_with("x-api-key", "super-secret-key"),
);
assert_eq!(p1, p2);
assert!(!p1.unwrap().contains("super-secret-key"));
}
// HS256 validator with a static secret that requires issuer "edgeguard-test".
fn hs_validator(secret: &str) -> JwtValidator {
JwtValidator::build(&JwtCfg {
algorithm: "HS256".into(),
secret: secret.into(),
issuer: "edgeguard-test".into(),
..Default::default()
})
.unwrap()
}
// Signs `claims` as an HS256 JWT with `secret`.
fn hs_token(secret: &str, claims: Value) -> String {
encode(
&Header::new(Algorithm::HS256),
&claims,
&EncodingKey::from_secret(secret.as_bytes()),
)
.unwrap()
}
// Unix timestamp one hour from now, usable as a not-yet-reached `exp` or `nbf`.
fn far_future() -> u64 {
use std::time::{SystemTime, UNIX_EPOCH};
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
+ 3600
}
// A correctly signed token yields its `sub` claim as the principal.
#[tokio::test]
async fn jwt_hs256_accepts_valid_and_returns_sub() {
let v = hs_validator("topsecret");
let token = hs_token(
"topsecret",
json!({ "sub": "user-42", "iss": "edgeguard-test", "exp": far_future() }),
);
let principal = v.verify(&token).await.unwrap();
assert_eq!(principal.as_deref(), Some("user-42"));
}
// Wrong signing secret, wrong issuer and long-past `exp` are each rejected.
#[tokio::test]
async fn jwt_hs256_rejects_bad_signature_wrong_issuer_and_expired() {
let v = hs_validator("topsecret");
let forged = hs_token(
"WRONG",
json!({ "sub": "x", "iss": "edgeguard-test", "exp": far_future() }),
);
assert!(v.verify(&forged).await.is_err());
let wrong_iss = hs_token(
"topsecret",
json!({ "sub": "x", "iss": "someone-else", "exp": far_future() }),
);
assert!(v.verify(&wrong_iss).await.is_err());
let expired = hs_token(
"topsecret",
json!({ "sub": "x", "iss": "edgeguard-test", "exp": 1_000 }),
);
assert!(v.verify(&expired).await.is_err());
}
// A token signed with the right secret but a different HMAC algorithm is rejected.
#[tokio::test]
async fn jwt_rejects_algorithm_confusion() {
let v = hs_validator("topsecret");
let mut header = Header::new(Algorithm::HS384);
header.kid = None;
let token = encode(
&header,
&json!({ "sub": "x", "iss": "edgeguard-test", "exp": far_future() }),
&EncodingKey::from_secret(b"topsecret"),
)
.unwrap();
assert!(v.verify(&token).await.is_err());
}
// `nbf` in the future is rejected because build() enables nbf validation.
#[tokio::test]
async fn jwt_hs256_rejects_not_yet_valid_token() {
let v = hs_validator("topsecret");
let token = hs_token(
"topsecret",
json!({ "sub": "x", "iss": "edgeguard-test", "exp": far_future(), "nbf": far_future() }),
);
assert!(v.verify(&token).await.is_err());
}
// Configuration errors surface at build time rather than per request.
#[test]
fn build_rejects_bad_algorithm_and_missing_secret() {
assert!(JwtValidator::build(&JwtCfg {
algorithm: "NOPE".into(),
..Default::default()
})
.is_err());
assert!(JwtValidator::build(&JwtCfg {
algorithm: "HS256".into(),
secret: String::new(),
..Default::default()
})
.is_err());
}
// An RSA JWK is indexed under its `kid`.
#[test]
fn parse_jwks_indexes_keys_by_kid() {
let jwks = json!({
"keys": [{
"kty": "RSA",
"kid": "key-1",
"use": "sig",
"alg": "RS256",
"n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
"e": "AQAB"
}]
})
.to_string();
let keys = parse_jwks(&jwks).unwrap();
assert!(
keys.contains_key("key-1"),
"kid not indexed: {:?}",
keys.keys()
);
}
// Non-JSON input and a key set with no keys are both errors.
#[test]
fn parse_jwks_rejects_empty_and_garbage() {
assert!(parse_jwks("not json").is_err());
assert!(parse_jwks(r#"{"keys":[]}"#).is_err());
}
}