1use argon2::{
10 password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
11 Argon2,
12};
13use axum::{
14 body::Body,
15 http::{header, Request, StatusCode},
16 middleware::Next,
17 response::{IntoResponse, Response},
18};
19use std::collections::HashSet;
20use std::sync::{Arc, Mutex};
21use std::time::{Duration, Instant};
22
/// Settings used to validate incoming JWT bearer tokens.
///
/// Build one with [`JwtConfig::new`] and attach a verification key with
/// [`JwtConfig::with_public_key`].
#[cfg(feature = "jwt")]
#[derive(Debug, Clone)]
pub struct JwtConfig {
    /// Expected `iss` claim.
    pub issuer: String,
    /// Expected `aud` claim.
    pub audience: String,
    /// PEM-encoded RSA public key used to verify signatures. When `None`, no
    /// decoding key is built and JWT validation always fails.
    pub public_key_pem: Option<String>,
    /// Signature algorithms accepted for tokens (`JwtConfig::new` uses RS256).
    /// NOTE(review): the key is parsed as an RSA PEM, so entries should be
    /// RSA-family algorithms — confirm with callers.
    pub allowed_algorithms: Vec<jsonwebtoken::Algorithm>,
}
40
#[cfg(feature = "jwt")]
impl JwtConfig {
    /// Creates a configuration for `issuer` and `audience` with no public key
    /// and RS256 as the only allowed algorithm.
    pub fn new(issuer: String, audience: String) -> Self {
        let allowed_algorithms = vec![jsonwebtoken::Algorithm::RS256];
        Self {
            public_key_pem: None,
            allowed_algorithms,
            issuer,
            audience,
        }
    }

    /// Returns this configuration with `pem` (RSA public key, PEM text) set as
    /// the verification key.
    pub fn with_public_key(self, pem: String) -> Self {
        Self {
            public_key_pem: Some(pem),
            ..self
        }
    }
}
59
/// Claims decoded from a validated JWT.
#[cfg(feature = "jwt")]
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct TokenClaims {
    /// Subject (`sub`) the token was issued for.
    pub sub: String,
    /// Optional e-mail address; `None` when the claim is absent.
    #[serde(default)]
    pub email: Option<String>,
    /// Role names; empty when the claim is absent.
    #[serde(default)]
    pub roles: Vec<String>,
    /// Optional tenant identifier; `None` when the claim is absent.
    #[serde(default)]
    pub tenant_id: Option<String>,
    /// Expiration time (`exp`), seconds since the Unix epoch.
    /// NOTE(review): `usize` is 32 bits on 32-bit targets — consider `u64`.
    pub exp: usize,
    /// Issuer (`iss`).
    pub iss: String,
    /// Audience (`aud`), kept as raw JSON because it may be a string or an array.
    #[serde(default)]
    pub aud: Option<serde_json::Value>,
}
83
/// Validates JWTs against a [`JwtConfig`].
///
/// `Debug` is implemented by hand so the decoding key is never printed.
#[cfg(feature = "jwt")]
#[derive(Clone)]
pub struct JwtValidator {
    /// Issuer, audience and algorithm settings.
    config: JwtConfig,
    /// Key parsed from `config.public_key_pem`, if one was supplied.
    decoding_key: Option<jsonwebtoken::DecodingKey>,
}
91
#[cfg(feature = "jwt")]
impl std::fmt::Debug for JwtValidator {
    /// Formats the validator with the decoding key replaced by a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let key_placeholder: Option<&str> = if self.decoding_key.is_some() {
            Some("[redacted]")
        } else {
            None
        };
        let mut builder = f.debug_struct("JwtValidator");
        builder.field("config", &self.config);
        builder.field("decoding_key", &key_placeholder);
        builder.finish()
    }
}
104
#[cfg(feature = "jwt")]
impl JwtValidator {
    /// Builds a validator from `config`, parsing the RSA public key if present.
    ///
    /// # Errors
    ///
    /// Returns an error string if `config.public_key_pem` is set but is not a
    /// valid RSA PEM key.
    pub fn new(config: JwtConfig) -> Result<Self, String> {
        let decoding_key = config
            .public_key_pem
            .as_deref()
            .map(|pem| {
                jsonwebtoken::DecodingKey::from_rsa_pem(pem.as_bytes())
                    .map_err(|e| format!("Invalid RSA PEM key: {}", e))
            })
            .transpose()?;

        Ok(Self {
            config,
            decoding_key,
        })
    }

    /// Decodes `token`, checking its signature, issuer, audience and expiry.
    ///
    /// Every algorithm in `config.allowed_algorithms` is accepted; RS256 is
    /// used when that list is empty.
    ///
    /// # Errors
    ///
    /// Returns an error string if no decoding key is configured or if the
    /// token fails validation.
    pub fn validate_token(&self, token: &str) -> Result<TokenClaims, String> {
        let decoding_key = self
            .decoding_key
            .as_ref()
            .ok_or_else(|| "No decoding key configured".to_string())?;

        let mut validation = jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::RS256);
        // `Validation::new` allows only the algorithm it is given; honour the whole
        // configured allow-list instead of silently dropping all but the first entry.
        if !self.config.allowed_algorithms.is_empty() {
            validation.algorithms = self.config.allowed_algorithms.clone();
        }
        validation.set_issuer(&[&self.config.issuer]);
        validation.set_audience(&[&self.config.audience]);
        validation.validate_exp = true;

        jsonwebtoken::decode::<TokenClaims>(token, decoding_key, &validation)
            .map(|token_data| token_data.claims)
            .map_err(|e| format!("JWT validation failed: {}", e))
    }
}
148
149#[derive(Clone, Debug)]
155pub struct AuthConfig {
156 pub enabled: bool,
158 hashed_keys: Vec<String>,
160 pub exempt_paths: HashSet<String>,
162 cache: Arc<Mutex<Vec<CacheEntry>>>,
164 #[cfg(feature = "jwt")]
166 pub jwt_validator: Option<JwtValidator>,
167}
168
/// One successful key verification remembered by the `AuthConfig` cache.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// `fast_hash` of the verified key; the key itself is not stored.
    key_hash: u64,
    /// Instant after which this entry no longer counts as a cache hit.
    expires_at: Instant,
}
176
177impl Default for AuthConfig {
178 fn default() -> Self {
179 Self {
180 enabled: false,
181 hashed_keys: Vec::new(),
182 exempt_paths: HashSet::from([
183 "/health".to_string(),
184 "/ready".to_string(),
185 "/live".to_string(),
186 "/metrics".to_string(),
187 ]),
188 cache: Arc::new(Mutex::new(Vec::new())),
189 #[cfg(feature = "jwt")]
190 jwt_validator: None,
191 }
192 }
193}
194
195impl AuthConfig {
196 pub fn with_api_keys(api_keys: Vec<String>) -> Self {
200 let argon2 = Argon2::default();
201 let hashed_keys: Vec<String> = api_keys
202 .iter()
203 .map(|key| {
204 let salt = SaltString::generate(&mut OsRng);
205 argon2
206 .hash_password(key.as_bytes(), &salt)
207 .expect("Argon2id hashing should not fail")
208 .to_string()
209 })
210 .collect();
211
212 Self {
213 enabled: true,
214 hashed_keys,
215 exempt_paths: HashSet::from([
216 "/health".to_string(),
217 "/ready".to_string(),
218 "/live".to_string(),
219 "/metrics".to_string(),
220 ]),
221 cache: Arc::new(Mutex::new(Vec::new())),
222 #[cfg(feature = "jwt")]
223 jwt_validator: None,
224 }
225 }
226
227 pub fn with_prehashed_keys(hashed_keys: Vec<String>) -> Self {
231 Self {
232 enabled: true,
233 hashed_keys,
234 exempt_paths: HashSet::from([
235 "/health".to_string(),
236 "/ready".to_string(),
237 "/live".to_string(),
238 "/metrics".to_string(),
239 ]),
240 cache: Arc::new(Mutex::new(Vec::new())),
241 #[cfg(feature = "jwt")]
242 jwt_validator: None,
243 }
244 }
245
246 #[cfg(feature = "jwt")]
248 pub fn with_jwt(mut self, config: JwtConfig) -> Result<Self, String> {
249 let validator = JwtValidator::new(config)?;
250 self.jwt_validator = Some(validator);
251 self.enabled = true;
252 Ok(self)
253 }
254
255 pub fn with_exempt_paths(mut self, paths: Vec<String>) -> Self {
257 for path in paths {
258 self.exempt_paths.insert(path);
259 }
260 self
261 }
262
263 fn verify_key(&self, submitted_key: &str) -> bool {
268 let key_hash = fast_hash(submitted_key);
269
270 {
272 let cache = self.cache.lock().unwrap_or_else(|e| e.into_inner());
273 let now = Instant::now();
274 for entry in cache.iter() {
275 if entry.key_hash == key_hash && entry.expires_at > now {
276 return true;
277 }
278 }
279 }
280
281 let argon2 = Argon2::default();
283 let mut any_match = false;
284
285 for stored_hash in &self.hashed_keys {
286 if let Ok(parsed_hash) = PasswordHash::new(stored_hash) {
287 if argon2
288 .verify_password(submitted_key.as_bytes(), &parsed_hash)
289 .is_ok()
290 {
291 any_match = true;
292 }
293 }
294 }
295
296 if any_match {
298 let mut cache = self.cache.lock().unwrap_or_else(|e| e.into_inner());
299 let now = Instant::now();
301 cache.retain(|e| e.expires_at > now);
302 cache.push(CacheEntry {
304 key_hash,
305 expires_at: now + Duration::from_secs(5),
306 });
307 }
308
309 any_match
310 }
311
312 fn verify_bearer(&self, token: &str) -> AuthResult {
314 #[cfg(feature = "jwt")]
316 if let Some(ref validator) = self.jwt_validator {
317 match validator.validate_token(token) {
318 Ok(_claims) => return AuthResult::Authenticated,
319 Err(_) => {
320 }
322 }
323 }
324
325 if self.verify_key(token) {
327 AuthResult::Authenticated
328 } else {
329 AuthResult::InvalidCredentials
330 }
331 }
332}
333
/// Outcome of verifying presented credentials.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AuthResult {
    /// The credentials were accepted.
    Authenticated,
    /// Credentials were presented but did not verify.
    InvalidCredentials,
}
339
/// Computes the 64-bit FNV-1a hash of the UTF-8 bytes of `s`.
///
/// Used only as a cheap lookup key for the verification cache; it is not a
/// cryptographic hash.
fn fast_hash(s: &str) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;

    s.as_bytes()
        .iter()
        .fold(OFFSET_BASIS, |acc, &byte| {
            (acc ^ u64::from(byte)).wrapping_mul(PRIME)
        })
}
350
351pub async fn auth_middleware(
357 axum::Extension(config): axum::Extension<AuthConfig>,
358 request: Request<Body>,
359 next: Next,
360) -> Response {
361 if !config.enabled {
363 return next.run(request).await;
364 }
365
366 let path = request.uri().path();
368 if config.exempt_paths.contains(path) {
369 return next.run(request).await;
370 }
371
372 let bearer_token = extract_bearer_token(&request);
374 let api_key = extract_x_api_key(&request);
375
376 if let Some(token) = bearer_token {
378 return match config.verify_bearer(&token) {
379 AuthResult::Authenticated => next.run(request).await,
380 AuthResult::InvalidCredentials => (
381 StatusCode::UNAUTHORIZED,
382 [(header::WWW_AUTHENTICATE, "Bearer")],
383 "Invalid credentials",
384 )
385 .into_response(),
386 };
387 }
388
389 if let Some(key) = api_key {
391 if config.verify_key(&key) {
392 return next.run(request).await;
393 }
394 return (
395 StatusCode::UNAUTHORIZED,
396 [(header::WWW_AUTHENTICATE, "Bearer")],
397 "Invalid API key",
398 )
399 .into_response();
400 }
401
402 (
404 StatusCode::UNAUTHORIZED,
405 [(header::WWW_AUTHENTICATE, "Bearer")],
406 "API key required. Provide via 'Authorization: Bearer <key>' or 'X-API-Key' header",
407 )
408 .into_response()
409}
410
411fn extract_bearer_token(request: &Request<Body>) -> Option<String> {
413 request
414 .headers()
415 .get(header::AUTHORIZATION)
416 .and_then(|h| h.to_str().ok())
417 .and_then(|s| s.strip_prefix("Bearer "))
418 .map(std::string::ToString::to_string)
419}
420
421fn extract_x_api_key(request: &Request<Body>) -> Option<String> {
423 request
424 .headers()
425 .get("X-API-Key")
426 .and_then(|h| h.to_str().ok())
427 .map(std::string::ToString::to_string)
428}
429
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        middleware,
        routing::get,
        Router,
    };
    use tower::ServiceExt;

    async fn test_handler() -> &'static str {
        "ok"
    }

    /// Router with one protected route and one exempt route behind the middleware.
    fn test_router(config: AuthConfig) -> Router {
        Router::new()
            .route("/api/test", get(test_handler))
            .route("/health", get(test_handler))
            .layer(middleware::from_fn(auth_middleware))
            .layer(axum::Extension(config))
    }

    /// Sends a GET to `uri` with at most one extra header and returns the status.
    async fn status_for(
        config: AuthConfig,
        uri: &str,
        extra_header: Option<(&str, &str)>,
    ) -> StatusCode {
        let mut builder = Request::builder().uri(uri);
        if let Some((name, value)) = extra_header {
            builder = builder.header(name, value);
        }
        let request = builder.body(Body::empty()).unwrap();
        test_router(config).oneshot(request).await.unwrap().status()
    }

    #[tokio::test]
    async fn test_auth_disabled() {
        let status = status_for(AuthConfig::default(), "/api/test", None).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_valid_bearer_token() {
        let config = AuthConfig::with_api_keys(vec!["test-key-123".to_string()]);
        let status = status_for(
            config,
            "/api/test",
            Some(("Authorization", "Bearer test-key-123")),
        )
        .await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_valid_x_api_key() {
        let config = AuthConfig::with_api_keys(vec!["test-key-456".to_string()]);
        let status = status_for(config, "/api/test", Some(("X-API-Key", "test-key-456"))).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_invalid_api_key() {
        let config = AuthConfig::with_api_keys(vec!["valid-key".to_string()]);
        let status =
            status_for(config, "/api/test", Some(("Authorization", "Bearer wrong-key"))).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_invalid_x_api_key() {
        let config = AuthConfig::with_api_keys(vec!["valid-key".to_string()]);
        let status = status_for(config, "/api/test", Some(("X-API-Key", "wrong-key"))).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_missing_api_key() {
        let config = AuthConfig::with_api_keys(vec!["valid-key".to_string()]);
        let status = status_for(config, "/api/test", None).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_exempt_path() {
        let config = AuthConfig::with_api_keys(vec!["valid-key".to_string()]);
        let status = status_for(config, "/health", None).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_prehashed_keys() {
        let argon2 = Argon2::default();
        let salt = SaltString::generate(&mut OsRng);
        let hash = argon2
            .hash_password(b"pre-hashed-key", &salt)
            .unwrap()
            .to_string();

        let config = AuthConfig::with_prehashed_keys(vec![hash]);
        let status = status_for(
            config,
            "/api/test",
            Some(("Authorization", "Bearer pre-hashed-key")),
        )
        .await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_cache_hit() {
        let config = AuthConfig::with_api_keys(vec!["cached-key".to_string()]);
        let auth = Some(("Authorization", "Bearer cached-key"));

        assert_eq!(
            status_for(config.clone(), "/api/test", auth).await,
            StatusCode::OK
        );
        // The successful Argon2 check must have populated the shared cache.
        assert_eq!(config.cache.lock().unwrap().len(), 1);

        assert_eq!(
            status_for(config.clone(), "/api/test", auth).await,
            StatusCode::OK
        );
        // A cache hit must not add a duplicate entry.
        assert_eq!(config.cache.lock().unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_api_key_fallback_still_works() {
        let config = AuthConfig::with_api_keys(vec!["my-key".to_string()]);
        let status = status_for(config, "/api/test", Some(("Authorization", "Bearer my-key"))).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[test]
    fn test_fast_hash_known_vectors() {
        // Standard FNV-1a 64-bit test vectors.
        assert_eq!(fast_hash(""), 0xcbf2_9ce4_8422_2325);
        assert_eq!(fast_hash("a"), 0xaf63_dc4c_8601_ec8c);
    }

    #[cfg(feature = "jwt")]
    mod jwt_tests {
        use super::*;

        #[test]
        fn test_jwt_config_creation() {
            let config =
                JwtConfig::new("https://auth.example.com".to_string(), "my-api".to_string());
            assert_eq!(config.issuer, "https://auth.example.com");
            assert_eq!(config.audience, "my-api");
            assert!(config.public_key_pem.is_none());
            assert_eq!(
                config.allowed_algorithms,
                vec![jsonwebtoken::Algorithm::RS256]
            );
        }

        #[test]
        fn test_jwt_validator_requires_key() {
            let config = JwtConfig::new("issuer".to_string(), "audience".to_string());
            let validator = JwtValidator::new(config).expect("should create");
            let result = validator.validate_token("some.invalid.token");
            assert!(result.is_err());
        }

        #[test]
        fn test_token_claims_deserialization() {
            let json = r#"{
                "sub": "user123",
                "email": "user@example.com",
                "roles": ["admin", "operator"],
                "tenant_id": "tenant1",
                "exp": 9999999999,
                "iss": "https://auth.example.com"
            }"#;
            let claims: TokenClaims = serde_json::from_str(json).unwrap();
            assert_eq!(claims.sub, "user123");
            assert_eq!(claims.email, Some("user@example.com".to_string()));
            assert_eq!(claims.roles, vec!["admin", "operator"]);
            assert_eq!(claims.tenant_id, Some("tenant1".to_string()));
        }
    }
}