1use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13
14use anyhow::{Context, Result};
15use argon2::{Argon2, PasswordHash, PasswordVerifier};
16use axum::http::{header, HeaderMap, HeaderName};
17use base64::{engine::general_purpose::STANDARD as B64, Engine};
18use jsonwebtoken::jwk::JwkSet;
19use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
20use serde_json::Value;
21use tokio::sync::RwLock;
22use tracing::warn;
23
24use crate::config::{AuthCfg, JwtCfg};
25
/// Outcome of authorizing one request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Decision {
    /// The request is allowed; carries the authenticated principal when one is known.
    Allow(Option<String>),
    /// The request is denied; carries the challenge to answer with.
    Deny(Challenge),
}

/// Challenge attached to a denied request.
///
/// NOTE(review): presumably rendered into a `WWW-Authenticate` header by the caller — confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Challenge {
    /// Basic-auth challenge; the string is the full value, e.g. `Basic realm="..."`.
    Basic(String),
    /// Bearer-token challenge (used by JWT mode).
    Bearer,
    /// No challenge (used by API-key mode).
    None,
}
42
/// The authentication strategy selected by `auth.mode`, built once by [`AuthEngine::build`].
pub enum AuthEngine {
    /// `none`: every request is allowed with no principal.
    Open,
    /// `basic`: HTTP Basic credentials checked against `AuthCfg::users`.
    Basic,
    /// `apikey`: a key read from `header` (or from a Bearer token) compared against `keys`.
    ApiKey {
        keys: Vec<String>,
        header: HeaderName,
    },
    /// `jwt`: a Bearer token verified by a [`JwtValidator`].
    Jwt(Box<JwtValidator>),
}
55
56impl AuthEngine {
57 pub fn build(cfg: &AuthCfg) -> Result<AuthEngine> {
61 match cfg.mode.as_str() {
62 "none" => Ok(AuthEngine::Open),
63 "basic" => Ok(AuthEngine::Basic),
64 "apikey" => {
65 let header = HeaderName::from_bytes(cfg.api_key_header.as_bytes())
66 .context("invalid auth.api_key_header")?;
67 Ok(AuthEngine::ApiKey {
68 keys: cfg.api_keys.clone(),
69 header,
70 })
71 }
72 "jwt" => Ok(AuthEngine::Jwt(Box::new(JwtValidator::build(&cfg.jwt)?))),
73 other => anyhow::bail!("unknown auth.mode: {other:?} (expected none|basic|apikey|jwt)"),
74 }
75 }
76
77 pub async fn authorize(&self, cfg: &AuthCfg, headers: &HeaderMap) -> Decision {
79 match self {
80 AuthEngine::Open => Decision::Allow(None),
81 AuthEngine::Basic => {
82 if check_basic_auth(cfg, headers) {
83 Decision::Allow(basic_username(headers))
85 } else {
86 Decision::Deny(Challenge::Basic(format!("Basic realm=\"{}\"", cfg.realm)))
87 }
88 }
89 AuthEngine::ApiKey { keys, header } => match verify_api_key(keys, header, headers) {
90 Some(principal) => Decision::Allow(Some(principal)),
91 None => Decision::Deny(Challenge::None),
92 },
93 AuthEngine::Jwt(v) => match bearer_token(headers) {
94 Some(token) => match v.verify(token).await {
95 Ok(principal) => Decision::Allow(principal),
96 Err(_) => Decision::Deny(Challenge::Bearer),
97 },
98 None => Decision::Deny(Challenge::Bearer),
99 },
100 }
101 }
102}
103
104pub fn check_basic_auth(cfg: &AuthCfg, headers: &HeaderMap) -> bool {
107 let Some((user, pass)) = basic_credentials(headers) else {
108 return false;
109 };
110 let Some(stored) = cfg.users.get(&user) else {
111 return false;
112 };
113 if stored.starts_with("$argon2") {
114 match PasswordHash::new(stored) {
115 Ok(parsed) => Argon2::default()
116 .verify_password(pass.as_bytes(), &parsed)
117 .is_ok(),
118 Err(_) => false,
119 }
120 } else {
121 constant_time_eq(stored.as_bytes(), pass.as_bytes())
123 }
124}
125
126fn basic_credentials(headers: &HeaderMap) -> Option<(String, String)> {
128 let auth = headers.get(header::AUTHORIZATION)?.to_str().ok()?;
129 let b64 = auth.strip_prefix("Basic ")?;
130 let decoded = B64.decode(b64.trim()).ok()?;
131 let creds = String::from_utf8(decoded).ok()?;
132 let (user, pass) = creds.split_once(':')?;
133 Some((user.to_string(), pass.to_string()))
134}
135
136fn basic_username(headers: &HeaderMap) -> Option<String> {
137 basic_credentials(headers).map(|(u, _)| u)
138}
139
140fn bearer_token(headers: &HeaderMap) -> Option<&str> {
142 headers
143 .get(header::AUTHORIZATION)?
144 .to_str()
145 .ok()?
146 .strip_prefix("Bearer ")
147 .map(str::trim)
148}
149
150pub fn verify_api_key(keys: &[String], header: &HeaderName, headers: &HeaderMap) -> Option<String> {
156 let presented = headers
157 .get(header)
158 .and_then(|v| v.to_str().ok())
159 .map(str::trim)
160 .or_else(|| bearer_token(headers))?;
161
162 let mut matched: Option<&String> = None;
163 for key in keys {
164 if constant_time_eq(key.as_bytes(), presented.as_bytes()) {
166 matched = Some(key);
167 }
168 }
169 matched.map(|k| format!("apikey:{}", short_id(k)))
170}
171
/// Derives a 16-hex-digit identifier for `secret`, so a principal can name an API key
/// without exposing it.
///
/// NOTE(review): std documents `DefaultHasher`'s algorithm as unspecified, so these ids
/// may change across Rust toolchain upgrades.
fn short_id(secret: &str) -> String {
    use std::hash::{Hash, Hasher};

    let mut state = std::collections::hash_map::DefaultHasher::new();
    Hash::hash(secret, &mut state);
    let digest = Hasher::finish(&state);
    format!("{:016x}", digest)
}
180
181pub fn hash_password(password: &str) -> Result<String> {
184 use argon2::password_hash::rand_core::OsRng;
185 use argon2::password_hash::{PasswordHasher, SaltString};
186
187 let salt = SaltString::generate(&mut OsRng);
188 Argon2::default()
189 .hash_password(password.as_bytes(), &salt)
190 .map(|h| h.to_string())
191 .map_err(|e| anyhow::anyhow!("hashing password: {e}"))
192}
193
/// Compares two byte strings without exiting early at the first mismatching byte.
///
/// The shorter input is padded with zeros up to the longer length. A length difference
/// is folded into the accumulator, so inputs that differ only by trailing zero bytes
/// still compare unequal. Work is proportional to the longer input.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    let longest = std::cmp::max(a.len(), b.len());
    let acc = (0..longest).fold(a.len() ^ b.len(), |acc, idx| {
        let lhs = a.get(idx).map_or(0, |v| *v);
        let rhs = b.get(idx).map_or(0, |v| *v);
        acc | usize::from(lhs ^ rhs)
    });
    acc == 0
}
208
/// Principal taken from a verified JWT: the `sub` claim when it is a string, else `None`.
type Principal = Option<String>;
215
/// Verifies JWTs against one configured algorithm, using either a static key or keys
/// fetched from a JWKS endpoint.
pub struct JwtValidator {
    /// The only algorithm accepted in a token header.
    alg: Algorithm,
    /// Claim-validation settings prepared in `build` (leeway, nbf, optional iss/aud).
    validation: Validation,
    /// Where verification keys come from.
    keys: KeySource,
}
225
/// Source of JWT verification keys.
enum KeySource {
    /// One key built at startup from `auth.jwt.secret` or `auth.jwt.public_key_pem`.
    Static(Arc<DecodingKey>),
    /// Keys fetched from `auth.jwt.jwks_url`, cached, and selected by the token's `kid`.
    Jwks(JwksCache),
}
232
233impl JwtValidator {
234 pub fn build(cfg: &JwtCfg) -> Result<JwtValidator> {
235 let alg = parse_algorithm(&cfg.algorithm)?;
236
237 let mut validation = Validation::new(alg);
238 validation.leeway = cfg.leeway_secs;
239 validation.validate_nbf = true;
242 if !cfg.issuer.is_empty() {
243 validation.set_issuer(std::slice::from_ref(&cfg.issuer));
244 }
245 if cfg.audience.is_empty() {
246 validation.validate_aud = false;
248 } else {
249 validation.set_audience(std::slice::from_ref(&cfg.audience));
250 }
251
252 let keys = if !cfg.jwks_url.is_empty() {
253 KeySource::Jwks(JwksCache::new(
254 cfg.jwks_url.clone(),
255 Duration::from_secs(cfg.jwks_cache_secs),
256 )?)
257 } else {
258 KeySource::Static(Arc::new(static_key(cfg, alg)?))
259 };
260
261 Ok(JwtValidator {
262 alg,
263 validation,
264 keys,
265 })
266 }
267
268 pub async fn verify(&self, token: &str) -> Result<Principal> {
270 let header = decode_header(token).context("malformed JWT header")?;
271 anyhow::ensure!(
273 header.alg == self.alg,
274 "token alg {:?} != configured {:?}",
275 header.alg,
276 self.alg
277 );
278
279 let key = match &self.keys {
280 KeySource::Static(k) => k.clone(),
281 KeySource::Jwks(cache) => cache.key_for(header.kid.as_deref()).await?,
282 };
283
284 let data = decode::<Value>(token, &key, &self.validation).context("JWT rejected")?;
285 let principal = data
286 .claims
287 .get("sub")
288 .and_then(|v| v.as_str())
289 .map(|s| s.to_string());
290 Ok(principal)
291 }
292}
293
294fn static_key(cfg: &JwtCfg, alg: Algorithm) -> Result<DecodingKey> {
296 use Algorithm::*;
297 match alg {
298 HS256 | HS384 | HS512 => {
299 anyhow::ensure!(
300 !cfg.secret.is_empty(),
301 "auth.jwt.secret (or $EDGEGUARD_JWT_SECRET) is required for HS* algorithms"
302 );
303 Ok(DecodingKey::from_secret(cfg.secret.as_bytes()))
304 }
305 RS256 | RS384 | RS512 | PS256 | PS384 | PS512 => {
306 anyhow::ensure!(
307 !cfg.public_key_pem.is_empty(),
308 "auth.jwt.public_key_pem (or jwks_url) is required for RS*/PS* algorithms"
309 );
310 DecodingKey::from_rsa_pem(cfg.public_key_pem.as_bytes())
311 .context("parsing auth.jwt.public_key_pem as RSA")
312 }
313 ES256 | ES384 => {
314 anyhow::ensure!(
315 !cfg.public_key_pem.is_empty(),
316 "auth.jwt.public_key_pem (or jwks_url) is required for ES* algorithms"
317 );
318 DecodingKey::from_ec_pem(cfg.public_key_pem.as_bytes())
319 .context("parsing auth.jwt.public_key_pem as EC")
320 }
321 EdDSA => {
322 anyhow::ensure!(
323 !cfg.public_key_pem.is_empty(),
324 "auth.jwt.public_key_pem (or jwks_url) is required for EdDSA"
325 );
326 DecodingKey::from_ed_pem(cfg.public_key_pem.as_bytes())
327 .context("parsing auth.jwt.public_key_pem as Ed25519")
328 }
329 }
330}
331
332fn parse_algorithm(s: &str) -> Result<Algorithm> {
333 Ok(match s.to_ascii_uppercase().as_str() {
334 "HS256" => Algorithm::HS256,
335 "HS384" => Algorithm::HS384,
336 "HS512" => Algorithm::HS512,
337 "RS256" => Algorithm::RS256,
338 "RS384" => Algorithm::RS384,
339 "RS512" => Algorithm::RS512,
340 "PS256" => Algorithm::PS256,
341 "PS384" => Algorithm::PS384,
342 "PS512" => Algorithm::PS512,
343 "ES256" => Algorithm::ES256,
344 "ES384" => Algorithm::ES384,
345 "EDDSA" => Algorithm::EdDSA,
346 other => anyhow::bail!("unsupported auth.jwt.algorithm: {other}"),
347 })
348}
349
350struct JwksCache {
353 url: String,
354 ttl: Duration,
355 http: reqwest::Client,
356 inner: RwLock<Option<CachedKeys>>,
357}
358
359struct CachedKeys {
360 fetched_at: Instant,
361 by_kid: HashMap<String, Arc<DecodingKey>>,
363}
364
365impl JwksCache {
366 fn new(url: String, ttl: Duration) -> Result<JwksCache> {
367 let http = reqwest::Client::builder()
368 .timeout(Duration::from_secs(5))
369 .build()
370 .context("building JWKS HTTP client")?;
371 Ok(JwksCache {
372 url,
373 ttl,
374 http,
375 inner: RwLock::new(None),
376 })
377 }
378
379 async fn key_for(&self, kid: Option<&str>) -> Result<Arc<DecodingKey>> {
382 if let Some(key) = self.lookup_fresh(kid).await {
383 return Ok(key);
384 }
385
386 let mut guard = self.inner.write().await;
391 let needs_fetch = match guard.as_ref() {
392 Some(c) => c.fetched_at.elapsed() > self.ttl || select_key(&c.by_kid, kid).is_none(),
393 None => true,
394 };
395 if needs_fetch {
396 match self.fetch().await {
397 Ok(by_kid) => {
398 *guard = Some(CachedKeys {
399 fetched_at: Instant::now(),
400 by_kid,
401 });
402 }
403 Err(e) if guard.is_some() => {
404 warn!(error = %format!("{e:#}"), "JWKS refresh failed; using cached keys");
405 }
406 Err(e) => return Err(e.context("JWKS refresh failed and no cached keys")),
407 }
408 }
409 if let Some(c) = guard.as_ref() {
410 if let Some(key) = select_key(&c.by_kid, kid) {
411 return Ok(key);
412 }
413 }
414 match kid {
415 Some(k) => anyhow::bail!("no JWKS key for kid {k:?}"),
416 None => anyhow::bail!("JWKS contains no usable key"),
417 }
418 }
419
420 async fn lookup_fresh(&self, kid: Option<&str>) -> Option<Arc<DecodingKey>> {
422 let guard = self.inner.read().await;
423 let cached = guard.as_ref()?;
424 if cached.fetched_at.elapsed() > self.ttl {
425 return None;
426 }
427 select_key(&cached.by_kid, kid)
428 }
429
430 async fn fetch(&self) -> Result<HashMap<String, Arc<DecodingKey>>> {
433 let body = self
434 .http
435 .get(&self.url)
436 .send()
437 .await
438 .with_context(|| format!("fetching JWKS from {}", self.url))?
439 .error_for_status()
440 .context("JWKS endpoint returned an error status")?
441 .text()
442 .await
443 .context("reading JWKS body")?;
444 parse_jwks(&body)
445 }
446}
447
448fn select_key(
451 by_kid: &HashMap<String, Arc<DecodingKey>>,
452 kid: Option<&str>,
453) -> Option<Arc<DecodingKey>> {
454 match kid {
455 Some(k) => by_kid.get(k).cloned(),
456 None if by_kid.len() == 1 => by_kid.values().next().cloned(),
457 None => by_kid.get("").cloned(),
458 }
459}
460
461fn parse_jwks(json: &str) -> Result<HashMap<String, Arc<DecodingKey>>> {
463 let set: JwkSet = serde_json::from_str(json).context("parsing JWKS JSON")?;
464 let mut by_kid = HashMap::new();
465 for jwk in &set.keys {
466 match DecodingKey::from_jwk(jwk) {
467 Ok(key) => {
468 let kid = jwk.common.key_id.clone().unwrap_or_default();
469 by_kid.insert(kid, Arc::new(key));
470 }
471 Err(e) => warn!(error = %e, "skipping unusable JWKS key"),
472 }
473 }
474 anyhow::ensure!(!by_kid.is_empty(), "JWKS contained no usable keys");
475 Ok(by_kid)
476}
477
// Unit tests for the auth primitives. They cover:
// - Basic auth, with plaintext and Argon2 stored secrets;
// - constant-time comparison;
// - API keys;
// - HS256 JWT validation;
// - JWKS parsing.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::AuthCfg;
    use jsonwebtoken::{encode, EncodingKey, Header};
    use serde_json::json;
    use std::collections::BTreeMap;

    /// Builds a `HeaderMap` containing exactly one header.
    fn headers_with(name: &'static str, value: &str) -> HeaderMap {
        let mut h = HeaderMap::new();
        h.insert(name, value.parse().unwrap());
        h
    }

    /// Formats an `Authorization` value of the form `Basic base64(user:pass)`.
    fn basic_value(user: &str, pass: &str) -> String {
        format!("Basic {}", B64.encode(format!("{user}:{pass}")))
    }

    /// An `AuthCfg` with a single user entry and defaults for everything else.
    fn cfg_with_user(user: &str, secret: &str) -> AuthCfg {
        AuthCfg {
            users: BTreeMap::from([(user.to_string(), secret.to_string())]),
            ..Default::default()
        }
    }

    // Plaintext secret: right password accepted; wrong password and unknown user rejected.
    #[test]
    fn basic_auth_plaintext_accepts_correct_rejects_bad() {
        let cfg = cfg_with_user("admin", "s3cret");
        assert!(check_basic_auth(
            &cfg,
            &headers_with("authorization", &basic_value("admin", "s3cret"))
        ));
        assert!(!check_basic_auth(
            &cfg,
            &headers_with("authorization", &basic_value("admin", "wrong"))
        ));
        assert!(!check_basic_auth(
            &cfg,
            &headers_with("authorization", &basic_value("ghost", "s3cret"))
        ));
    }

    // No header, a non-Basic scheme, and an undecodable payload all fail closed.
    #[test]
    fn basic_auth_rejects_missing_and_malformed_headers() {
        let cfg = cfg_with_user("admin", "s3cret");
        assert!(!check_basic_auth(&cfg, &HeaderMap::new()));
        assert!(!check_basic_auth(
            &cfg,
            &headers_with("authorization", "Bearer token")
        ));
        assert!(!check_basic_auth(
            &cfg,
            &headers_with("authorization", "Basic !!!not-base64!!!")
        ));
    }

    // A secret produced by `hash_password` is verified through the Argon2 branch.
    #[test]
    fn basic_auth_argon2_path() {
        let phc = hash_password("hunter2").unwrap();
        assert!(phc.starts_with("$argon2"), "{phc}");
        let cfg = cfg_with_user("admin", &phc);
        assert!(check_basic_auth(
            &cfg,
            &headers_with("authorization", &basic_value("admin", "hunter2"))
        ));
        assert!(!check_basic_auth(
            &cfg,
            &headers_with("authorization", &basic_value("admin", "nope"))
        ));
    }

    // Length differences, including against an empty input, must make the comparison fail.
    #[test]
    fn constant_time_eq_handles_differing_lengths() {
        assert!(constant_time_eq(b"abc", b"abc"));
        assert!(!constant_time_eq(b"abc", b"abd"));
        assert!(!constant_time_eq(b"abc", b"abcd"));
        assert!(!constant_time_eq(b"", b"x"));
        assert!(constant_time_eq(b"", b""));
    }

    // A key is accepted from the custom header or from a Bearer token; unknown or missing
    // keys are rejected.
    #[test]
    fn api_key_accepts_via_bearer_and_header_rejects_unknown() {
        let keys = vec!["sk_live_abc".to_string(), "sk_live_def".to_string()];
        let header = HeaderName::from_static("x-api-key");

        assert!(
            verify_api_key(&keys, &header, &headers_with("x-api-key", "sk_live_abc")).is_some()
        );
        assert!(verify_api_key(
            &keys,
            &header,
            &headers_with("authorization", "Bearer sk_live_def")
        )
        .is_some());
        assert!(verify_api_key(&keys, &header, &headers_with("x-api-key", "nope")).is_none());
        assert!(verify_api_key(&keys, &header, &HeaderMap::new()).is_none());
    }

    // The principal is deterministic across calls and never embeds the raw key.
    #[test]
    fn api_key_principal_is_stable_and_not_the_raw_key() {
        let keys = vec!["super-secret-key".to_string()];
        let header = HeaderName::from_static("x-api-key");
        let p1 = verify_api_key(
            &keys,
            &header,
            &headers_with("x-api-key", "super-secret-key"),
        );
        let p2 = verify_api_key(
            &keys,
            &header,
            &headers_with("x-api-key", "super-secret-key"),
        );
        assert_eq!(p1, p2);
        assert!(!p1.unwrap().contains("super-secret-key"));
    }

    /// An HS256 validator with the given secret that requires issuer `edgeguard-test`.
    fn hs_validator(secret: &str) -> JwtValidator {
        JwtValidator::build(&JwtCfg {
            algorithm: "HS256".into(),
            secret: secret.into(),
            issuer: "edgeguard-test".into(),
            ..Default::default()
        })
        .unwrap()
    }

    /// Signs `claims` as an HS256 JWT with `secret`.
    fn hs_token(secret: &str, claims: Value) -> String {
        encode(
            &Header::new(Algorithm::HS256),
            &claims,
            &EncodingKey::from_secret(secret.as_bytes()),
        )
        .unwrap()
    }

    /// A Unix timestamp one hour from now.
    fn far_future() -> u64 {
        use std::time::{SystemTime, UNIX_EPOCH};
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
            + 3600
    }

    // A valid token yields its `sub` claim as the principal.
    #[tokio::test]
    async fn jwt_hs256_accepts_valid_and_returns_sub() {
        let v = hs_validator("topsecret");
        let token = hs_token(
            "topsecret",
            json!({ "sub": "user-42", "iss": "edgeguard-test", "exp": far_future() }),
        );
        let principal = v.verify(&token).await.unwrap();
        assert_eq!(principal.as_deref(), Some("user-42"));
    }

    // A wrong signing secret, a foreign issuer, and an `exp` in 1970 are each rejected.
    #[tokio::test]
    async fn jwt_hs256_rejects_bad_signature_wrong_issuer_and_expired() {
        let v = hs_validator("topsecret");

        let forged = hs_token(
            "WRONG",
            json!({ "sub": "x", "iss": "edgeguard-test", "exp": far_future() }),
        );
        assert!(v.verify(&forged).await.is_err());

        let wrong_iss = hs_token(
            "topsecret",
            json!({ "sub": "x", "iss": "someone-else", "exp": far_future() }),
        );
        assert!(v.verify(&wrong_iss).await.is_err());

        let expired = hs_token(
            "topsecret",
            json!({ "sub": "x", "iss": "edgeguard-test", "exp": 1_000 }),
        );
        assert!(v.verify(&expired).await.is_err());
    }

    // A correctly signed token whose header `alg` (HS384) differs from the configured one
    // (HS256) must be rejected.
    #[tokio::test]
    async fn jwt_rejects_algorithm_confusion() {
        let v = hs_validator("topsecret");
        let mut header = Header::new(Algorithm::HS384);
        header.kid = None;
        let token = encode(
            &header,
            &json!({ "sub": "x", "iss": "edgeguard-test", "exp": far_future() }),
            &EncodingKey::from_secret(b"topsecret"),
        )
        .unwrap();
        assert!(v.verify(&token).await.is_err());
    }

    // `nbf` an hour in the future must be rejected.
    // NOTE(review): assumes the default `leeway_secs` is below one hour.
    #[tokio::test]
    async fn jwt_hs256_rejects_not_yet_valid_token() {
        let v = hs_validator("topsecret");
        let token = hs_token(
            "topsecret",
            json!({ "sub": "x", "iss": "edgeguard-test", "exp": far_future(), "nbf": far_future() }),
        );
        assert!(v.verify(&token).await.is_err());
    }

    // An unknown algorithm name and an HS256 config without a secret both fail to build.
    #[test]
    fn build_rejects_bad_algorithm_and_missing_secret() {
        assert!(JwtValidator::build(&JwtCfg {
            algorithm: "NOPE".into(),
            ..Default::default()
        })
        .is_err());
        assert!(JwtValidator::build(&JwtCfg {
            algorithm: "HS256".into(),
            secret: String::new(),
            ..Default::default()
        })
        .is_err());
    }

    // An RSA JWK is indexed under its `kid`.
    #[test]
    fn parse_jwks_indexes_keys_by_kid() {
        let jwks = json!({
            "keys": [{
                "kty": "RSA",
                "kid": "key-1",
                "use": "sig",
                "alg": "RS256",
                "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
                "e": "AQAB"
            }]
        })
        .to_string();
        let keys = parse_jwks(&jwks).unwrap();
        assert!(
            keys.contains_key("key-1"),
            "kid not indexed: {:?}",
            keys.keys()
        );
    }

    // Non-JSON input and a key set with no keys are both errors.
    #[test]
    fn parse_jwks_rejects_empty_and_garbage() {
        assert!(parse_jwks("not json").is_err());
        assert!(parse_jwks(r#"{"keys":[]}"#).is_err());
    }
}