1use argon2::{
10 password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
11 Argon2,
12};
13use axum::{
14 body::Body,
15 http::{header, Request, StatusCode},
16 middleware::Next,
17 response::{IntoResponse, Response},
18};
19use std::collections::HashSet;
20use std::sync::{Arc, Mutex};
21use std::time::{Duration, Instant};
22
/// Settings for validating RSA-signed JWTs presented as bearer tokens.
///
/// Build one with [`JwtConfig::new`] and attach the verification key with
/// [`JwtConfig::with_public_key`].
#[cfg(feature = "jwt")]
#[derive(Clone, Debug)]
pub struct JwtConfig {
    /// Issuer the token's `iss` claim is checked against.
    pub issuer: String,
    /// Audience the token's `aud` claim is checked against.
    pub audience: String,
    /// PEM-encoded RSA public key used to verify signatures. With `None`, no
    /// token can be validated.
    pub public_key_pem: Option<String>,
    /// Signature algorithms to accept; see `JwtValidator::validate_token` for
    /// how this list is applied.
    pub allowed_algorithms: Vec<jsonwebtoken::Algorithm>,
}
40
#[cfg(feature = "jwt")]
impl JwtConfig {
    /// Creates a config for `issuer` / `audience` with no public key and with
    /// `RS256` as the only allowed algorithm.
    pub fn new(issuer: String, audience: String) -> Self {
        let allowed_algorithms = vec![jsonwebtoken::Algorithm::RS256];
        Self {
            issuer,
            audience,
            public_key_pem: None,
            allowed_algorithms,
        }
    }

    /// Returns this config with `pem` (PEM-encoded RSA public key text) set as
    /// the verification key, replacing any previous key.
    pub fn with_public_key(self, pem: String) -> Self {
        Self {
            public_key_pem: Some(pem),
            ..self
        }
    }
}
59
/// Claims decoded from a validated JWT.
///
/// Only `sub`, `exp` and `iss` must be present in the token payload; the
/// remaining fields fall back to their defaults when missing.
#[cfg(feature = "jwt")]
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TokenClaims {
    /// Subject the token was issued to.
    pub sub: String,
    /// Optional e-mail address of the subject.
    #[serde(default)]
    pub email: Option<String>,
    /// Role names carried by the token; empty when the claim is absent.
    #[serde(default)]
    pub roles: Vec<String>,
    /// Optional tenant identifier.
    #[serde(default)]
    pub tenant_id: Option<String>,
    /// Expiry time (JWT NumericDate, seconds since the Unix epoch).
    pub exp: usize,
    /// Issuer of the token.
    pub iss: String,
    /// Audience claim kept as raw JSON, since JWTs allow a string or an array.
    #[serde(default)]
    pub aud: Option<serde_json::Value>,
}
83
/// Validates bearer JWTs against a [`JwtConfig`].
///
/// The RSA decoding key is parsed once when the validator is built. It is
/// `None` when the config had no public key.
#[cfg(feature = "jwt")]
#[derive(Clone)]
pub struct JwtValidator {
    config: JwtConfig,
    decoding_key: Option<jsonwebtoken::DecodingKey>,
}
91
/// Hand-written so the decoding key is never printed; only its presence is shown.
#[cfg(feature = "jwt")]
impl std::fmt::Debug for JwtValidator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `Some("[redacted]")` when a key is loaded, `None` otherwise.
        let key_marker: Option<&str> = self.decoding_key.is_some().then_some("[redacted]");
        let mut out = f.debug_struct("JwtValidator");
        out.field("config", &self.config);
        out.field("decoding_key", &key_marker);
        out.finish()
    }
}
104
#[cfg(feature = "jwt")]
impl JwtValidator {
    /// Builds a validator from `config`, parsing its RSA public key if one is set.
    ///
    /// A config without a key is accepted. Such a validator rejects every
    /// token with `"No decoding key configured"`.
    ///
    /// # Errors
    ///
    /// Returns `Invalid RSA PEM key: ...` when `config.public_key_pem` is set
    /// but cannot be parsed as an RSA PEM key.
    pub fn new(config: JwtConfig) -> Result<Self, String> {
        let decoding_key = config
            .public_key_pem
            .as_deref()
            .map(|pem| {
                jsonwebtoken::DecodingKey::from_rsa_pem(pem.as_bytes())
                    .map_err(|e| format!("Invalid RSA PEM key: {}", e))
            })
            .transpose()?;

        Ok(Self {
            config,
            decoding_key,
        })
    }

    /// Verifies `token`'s signature, expiry, issuer and audience, then returns its claims.
    ///
    /// Every RSA-family algorithm (`RS*` / `PS*`) listed in
    /// `config.allowed_algorithms` is accepted, not just the first one. An
    /// empty list means `RS256` only.
    ///
    /// # Errors
    ///
    /// - `"No decoding key configured"` when the config had no public key.
    /// - `"No RSA signature algorithm configured"` when the list is non-empty
    ///   but contains no algorithm usable with the RSA key.
    /// - `JWT validation failed: ...` for any decoding or claim-validation failure.
    pub fn validate_token(&self, token: &str) -> Result<TokenClaims, String> {
        let decoding_key = self
            .decoding_key
            .as_ref()
            .ok_or_else(|| "No decoding key configured".to_string())?;

        let algorithms = self.accepted_algorithms()?;
        // `accepted_algorithms` never returns an empty list, so indexing is safe.
        let mut validation = jsonwebtoken::Validation::new(algorithms[0]);
        validation.algorithms = algorithms;
        validation.set_issuer(&[&self.config.issuer]);
        validation.set_audience(&[&self.config.audience]);
        validation.validate_exp = true;

        jsonwebtoken::decode::<TokenClaims>(token, decoding_key, &validation)
            .map(|data| data.claims)
            .map_err(|e| format!("JWT validation failed: {}", e))
    }

    /// Algorithms passed to `jsonwebtoken`, restricted to the RSA family
    /// because the decoding key is always loaded with `from_rsa_pem`.
    ///
    /// Returns `[RS256]` for an empty config list. Returns an error when the
    /// list names only non-RSA algorithms. Pairing e.g. `HS256` with an RSA
    /// key would at best never validate, so failing closed is preferred.
    fn accepted_algorithms(&self) -> Result<Vec<jsonwebtoken::Algorithm>, String> {
        use jsonwebtoken::Algorithm::{PS256, PS384, PS512, RS256, RS384, RS512};

        let configured = &self.config.allowed_algorithms;
        if configured.is_empty() {
            return Ok(vec![RS256]);
        }

        let rsa: Vec<jsonwebtoken::Algorithm> = configured
            .iter()
            .copied()
            .filter(|&alg| matches!(alg, RS256 | RS384 | RS512 | PS256 | PS384 | PS512))
            .collect();
        if rsa.is_empty() {
            Err("No RSA signature algorithm configured".to_string())
        } else {
            Ok(rsa)
        }
    }
}
148
149#[derive(Clone, Debug)]
155pub struct AuthConfig {
156 pub enabled: bool,
158 hashed_keys: Vec<String>,
160 pub exempt_paths: HashSet<String>,
162 cache: Arc<Mutex<Vec<CacheEntry>>>,
164 #[cfg(feature = "jwt")]
166 pub jwt_validator: Option<JwtValidator>,
167}
168
169#[derive(Clone, Debug)]
170struct CacheEntry {
171 key_hash: u64,
173 expires_at: Instant,
175}
176
177impl Default for AuthConfig {
178 fn default() -> Self {
179 Self {
180 enabled: false,
181 hashed_keys: Vec::new(),
182 exempt_paths: HashSet::from([
183 "/health".to_string(),
184 "/ready".to_string(),
185 "/live".to_string(),
186 "/metrics".to_string(),
187 ]),
188 cache: Arc::new(Mutex::new(Vec::new())),
189 #[cfg(feature = "jwt")]
190 jwt_validator: None,
191 }
192 }
193}
194
195impl AuthConfig {
196 pub fn with_api_keys(api_keys: Vec<String>) -> Self {
200 let argon2 = Argon2::default();
201 let hashed_keys: Vec<String> = api_keys
202 .iter()
203 .map(|key| {
204 let salt = SaltString::generate(&mut OsRng);
205 argon2
206 .hash_password(key.as_bytes(), &salt)
207 .expect("Argon2id hashing should not fail")
208 .to_string()
209 })
210 .collect();
211
212 Self {
213 enabled: true,
214 hashed_keys,
215 exempt_paths: HashSet::from([
216 "/health".to_string(),
217 "/ready".to_string(),
218 "/live".to_string(),
219 "/metrics".to_string(),
220 ]),
221 cache: Arc::new(Mutex::new(Vec::new())),
222 #[cfg(feature = "jwt")]
223 jwt_validator: None,
224 }
225 }
226
227 pub fn with_prehashed_keys(hashed_keys: Vec<String>) -> Self {
231 Self {
232 enabled: true,
233 hashed_keys,
234 exempt_paths: HashSet::from([
235 "/health".to_string(),
236 "/ready".to_string(),
237 "/live".to_string(),
238 "/metrics".to_string(),
239 ]),
240 cache: Arc::new(Mutex::new(Vec::new())),
241 #[cfg(feature = "jwt")]
242 jwt_validator: None,
243 }
244 }
245
246 #[cfg(feature = "jwt")]
248 pub fn with_jwt(mut self, config: JwtConfig) -> Result<Self, String> {
249 let validator = JwtValidator::new(config)?;
250 self.jwt_validator = Some(validator);
251 self.enabled = true;
252 Ok(self)
253 }
254
255 pub fn with_exempt_paths(mut self, paths: Vec<String>) -> Self {
257 for path in paths {
258 self.exempt_paths.insert(path);
259 }
260 self
261 }
262
263 fn verify_key(&self, submitted_key: &str) -> bool {
268 let key_hash = fast_hash(submitted_key);
269
270 {
272 let cache = self.cache.lock().unwrap();
273 let now = Instant::now();
274 for entry in cache.iter() {
275 if entry.key_hash == key_hash && entry.expires_at > now {
276 return true;
277 }
278 }
279 }
280
281 let argon2 = Argon2::default();
283 let mut any_match = false;
284
285 for stored_hash in &self.hashed_keys {
286 if let Ok(parsed_hash) = PasswordHash::new(stored_hash) {
287 if argon2
288 .verify_password(submitted_key.as_bytes(), &parsed_hash)
289 .is_ok()
290 {
291 any_match = true;
292 }
293 }
294 }
295
296 if any_match {
298 let mut cache = self.cache.lock().unwrap();
299 let now = Instant::now();
301 cache.retain(|e| e.expires_at > now);
302 cache.push(CacheEntry {
304 key_hash,
305 expires_at: now + Duration::from_secs(5),
306 });
307 }
308
309 any_match
310 }
311
312 fn verify_bearer(&self, token: &str) -> AuthResult {
314 #[cfg(feature = "jwt")]
316 if let Some(ref validator) = self.jwt_validator {
317 match validator.validate_token(token) {
318 Ok(_claims) => return AuthResult::Authenticated,
319 Err(_) => {
320 }
322 }
323 }
324
325 if self.verify_key(token) {
327 AuthResult::Authenticated
328 } else {
329 AuthResult::InvalidCredentials
330 }
331 }
332}
333
/// Outcome of checking the credentials presented on a request.
enum AuthResult {
    /// The credentials were accepted.
    Authenticated,
    /// Credentials were presented but rejected.
    InvalidCredentials,
}
339
/// 64-bit FNV-1a hash of the UTF-8 bytes of `s`.
///
/// Fast but not collision resistant. It is only suitable as a lookup
/// pre-filter, never as proof that two secrets are equal.
fn fast_hash(s: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;

    s.as_bytes().iter().fold(FNV_OFFSET_BASIS, |acc, &byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
350
351pub async fn auth_middleware(
357 axum::Extension(config): axum::Extension<AuthConfig>,
358 request: Request<Body>,
359 next: Next,
360) -> Response {
361 if !config.enabled {
363 return next.run(request).await;
364 }
365
366 let path = request.uri().path();
368 if config.exempt_paths.contains(path) {
369 return next.run(request).await;
370 }
371
372 let bearer_token = extract_bearer_token(&request);
374 let api_key = extract_x_api_key(&request);
375
376 if let Some(token) = bearer_token {
378 return match config.verify_bearer(&token) {
379 AuthResult::Authenticated => next.run(request).await,
380 AuthResult::InvalidCredentials => (
381 StatusCode::UNAUTHORIZED,
382 [(header::WWW_AUTHENTICATE, "Bearer")],
383 "Invalid credentials",
384 )
385 .into_response(),
386 };
387 }
388
389 if let Some(key) = api_key {
391 if config.verify_key(&key) {
392 return next.run(request).await;
393 }
394 return (
395 StatusCode::UNAUTHORIZED,
396 [(header::WWW_AUTHENTICATE, "Bearer")],
397 "Invalid API key",
398 )
399 .into_response();
400 }
401
402 (
404 StatusCode::UNAUTHORIZED,
405 [(header::WWW_AUTHENTICATE, "Bearer")],
406 "API key required. Provide via 'Authorization: Bearer <key>' or 'X-API-Key' header",
407 )
408 .into_response()
409}
410
411fn extract_bearer_token(request: &Request<Body>) -> Option<String> {
413 request
414 .headers()
415 .get(header::AUTHORIZATION)
416 .and_then(|h| h.to_str().ok())
417 .and_then(|s| s.strip_prefix("Bearer "))
418 .map(|s| s.to_string())
419}
420
421fn extract_x_api_key(request: &Request<Body>) -> Option<String> {
423 request
424 .headers()
425 .get("X-API-Key")
426 .and_then(|h| h.to_str().ok())
427 .map(|s| s.to_string())
428}
429
#[cfg(test)]
// Test setup failures should panic loudly, so `unwrap` is acceptable here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        middleware,
        routing::get,
        Router,
    };
    use tower::ServiceExt;

    async fn test_handler() -> &'static str {
        "ok"
    }

    /// Router with one protected route (`/api/test`) and one default-exempt route (`/health`).
    fn test_router(config: AuthConfig) -> Router {
        Router::new()
            .route("/api/test", get(test_handler))
            .route("/health", get(test_handler))
            .layer(middleware::from_fn(auth_middleware))
            .layer(axum::Extension(config))
    }

    /// Builds a request for `uri` carrying the given `(name, value)` headers.
    fn request(uri: &str, headers: &[(&str, &str)]) -> Request<Body> {
        let mut builder = Request::builder().uri(uri);
        for (name, value) in headers {
            builder = builder.header(*name, *value);
        }
        builder.body(Body::empty()).unwrap()
    }

    /// Sends `req` through a router configured with `config` and returns the status.
    async fn status_for(config: AuthConfig, req: Request<Body>) -> StatusCode {
        test_router(config).oneshot(req).await.unwrap().status()
    }

    #[test]
    fn test_default_config() {
        let config = AuthConfig::default();
        assert!(!config.enabled);
        for path in ["/health", "/ready", "/live", "/metrics"] {
            assert!(config.exempt_paths.contains(path), "{} should be exempt", path);
        }
    }

    #[test]
    fn test_fast_hash_known_vectors() {
        // Published FNV-1a 64-bit test vectors.
        assert_eq!(fast_hash(""), 0xcbf29ce484222325);
        assert_eq!(fast_hash("a"), 0xaf63dc4c8601ec8c);
        assert_eq!(fast_hash("foobar"), 0x85944171f73967e8);
    }

    #[tokio::test]
    async fn test_auth_disabled() {
        let status = status_for(AuthConfig::default(), request("/api/test", &[])).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_valid_bearer_token() {
        let config = AuthConfig::with_api_keys(vec!["test-key-123".to_string()]);
        let req = request("/api/test", &[("Authorization", "Bearer test-key-123")]);
        assert_eq!(status_for(config, req).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_valid_x_api_key() {
        let config = AuthConfig::with_api_keys(vec!["test-key-456".to_string()]);
        let req = request("/api/test", &[("X-API-Key", "test-key-456")]);
        assert_eq!(status_for(config, req).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_invalid_api_key() {
        let config = AuthConfig::with_api_keys(vec!["valid-key".to_string()]);
        let req = request("/api/test", &[("Authorization", "Bearer wrong-key")]);
        assert_eq!(status_for(config, req).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_invalid_x_api_key() {
        let config = AuthConfig::with_api_keys(vec!["valid-key".to_string()]);
        let req = request("/api/test", &[("X-API-Key", "wrong-key")]);
        assert_eq!(status_for(config, req).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_bearer_takes_precedence_over_x_api_key() {
        // A rejected bearer token is not rescued by a valid X-API-Key.
        let config = AuthConfig::with_api_keys(vec!["valid-key".to_string()]);
        let req = request(
            "/api/test",
            &[("Authorization", "Bearer wrong-key"), ("X-API-Key", "valid-key")],
        );
        assert_eq!(status_for(config, req).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_missing_api_key() {
        let config = AuthConfig::with_api_keys(vec!["valid-key".to_string()]);
        let req = request("/api/test", &[]);
        assert_eq!(status_for(config, req).await, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_exempt_path() {
        let config = AuthConfig::with_api_keys(vec!["valid-key".to_string()]);
        let req = request("/health", &[]);
        assert_eq!(status_for(config, req).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_custom_exempt_path() {
        let config = AuthConfig::with_api_keys(vec!["valid-key".to_string()])
            .with_exempt_paths(vec!["/api/test".to_string()]);
        let req = request("/api/test", &[]);
        assert_eq!(status_for(config, req).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_prehashed_keys() {
        let argon2 = Argon2::default();
        let salt = SaltString::generate(&mut OsRng);
        let hash = argon2
            .hash_password(b"pre-hashed-key", &salt)
            .unwrap()
            .to_string();

        let config = AuthConfig::with_prehashed_keys(vec![hash]);
        let req = request("/api/test", &[("Authorization", "Bearer pre-hashed-key")]);
        assert_eq!(status_for(config, req).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_cache_hit() {
        let config = AuthConfig::with_api_keys(vec!["cached-key".to_string()]);
        let header = [("Authorization", "Bearer cached-key")];

        let first = status_for(config.clone(), request("/api/test", &header)).await;
        assert_eq!(first, StatusCode::OK);
        // Clones share the Arc-backed cache, so the verified key is visible here.
        assert_eq!(config.cache.lock().unwrap().len(), 1);

        let second = status_for(config.clone(), request("/api/test", &header)).await;
        assert_eq!(second, StatusCode::OK);
        // A repeat success must not grow the cache with a duplicate entry.
        assert_eq!(config.cache.lock().unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_api_key_fallback_still_works() {
        let config = AuthConfig::with_api_keys(vec!["my-key".to_string()]);
        let req = request("/api/test", &[("Authorization", "Bearer my-key")]);
        assert_eq!(status_for(config, req).await, StatusCode::OK);
    }

    #[cfg(feature = "jwt")]
    mod jwt_tests {
        use super::*;

        #[test]
        fn test_jwt_config_creation() {
            let config =
                JwtConfig::new("https://auth.example.com".to_string(), "my-api".to_string());
            assert_eq!(config.issuer, "https://auth.example.com");
            assert_eq!(config.audience, "my-api");
            assert!(config.public_key_pem.is_none());
            assert_eq!(
                config.allowed_algorithms,
                vec![jsonwebtoken::Algorithm::RS256]
            );
        }

        #[test]
        fn test_jwt_validator_requires_key() {
            let config = JwtConfig::new("issuer".to_string(), "audience".to_string());
            let validator = JwtValidator::new(config).expect("should create");
            let result = validator.validate_token("some.invalid.token");
            assert_eq!(result.err().unwrap(), "No decoding key configured");
        }

        #[test]
        fn test_jwt_validator_rejects_bad_pem() {
            let config = JwtConfig::new("issuer".to_string(), "audience".to_string())
                .with_public_key("not a pem key".to_string());
            let err = JwtValidator::new(config).err().expect("bad PEM must be rejected");
            assert!(err.starts_with("Invalid RSA PEM key"));
        }

        #[test]
        fn test_token_claims_deserialization() {
            let json = r#"{
                "sub": "user123",
                "email": "user@example.com",
                "roles": ["admin", "operator"],
                "tenant_id": "tenant1",
                "exp": 9999999999,
                "iss": "https://auth.example.com"
            }"#;
            let claims: TokenClaims = serde_json::from_str(json).unwrap();
            assert_eq!(claims.sub, "user123");
            assert_eq!(claims.email, Some("user@example.com".to_string()));
            assert_eq!(claims.roles, vec!["admin", "operator"]);
            assert_eq!(claims.tenant_id, Some("tenant1".to_string()));
        }

        #[test]
        fn test_token_claims_optional_fields_default() {
            let json = r#"{"sub": "svc", "exp": 1, "iss": "issuer"}"#;
            let claims: TokenClaims = serde_json::from_str(json).unwrap();
            assert!(claims.email.is_none());
            assert!(claims.roles.is_empty());
            assert!(claims.tenant_id.is_none());
            assert!(claims.aud.is_none());
        }
    }
}