Skip to main content

datasynth_server/rest/
auth.rs

1//! Authentication middleware for REST API.
2//!
3//! Provides API key authentication with Argon2id hashing and
4//! timing-safe comparison for protecting endpoints.
5//!
6//! When the `jwt` feature is enabled, also supports JWT validation
7//! from external OIDC providers (Keycloak, Auth0, Entra ID).
8
9use argon2::{
10    password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
11    Argon2,
12};
13use axum::{
14    body::Body,
15    http::{header, Request, StatusCode},
16    middleware::Next,
17    response::{IntoResponse, Response},
18};
19use std::collections::HashSet;
20use std::sync::{Arc, Mutex};
21use std::time::{Duration, Instant};
22
23// ===========================================================================
24// JWT types (feature-gated)
25// ===========================================================================
26
/// Settings for validating JWTs issued by an external OIDC provider
/// (Keycloak, Auth0, Entra ID, ...).
#[cfg(feature = "jwt")]
#[derive(Debug, Clone)]
pub struct JwtConfig {
    /// Issuer (`iss`) accepted tokens must carry,
    /// e.g. `"https://auth.example.com/realms/main"`.
    pub issuer: String,
    /// Audience (`aud`) accepted tokens must include.
    pub audience: String,
    /// PEM-encoded RSA public key used to check signatures. While this is
    /// `None`, no token can be validated.
    pub public_key_pem: Option<String>,
    /// Allowed signature algorithms; [`JwtConfig::new`] defaults this to RS256 only.
    pub allowed_algorithms: Vec<jsonwebtoken::Algorithm>,
}
40
#[cfg(feature = "jwt")]
impl JwtConfig {
    /// Build a config for `issuer` / `audience` that allows RS256 only and has
    /// no public key yet (attach one with [`JwtConfig::with_public_key`]).
    pub fn new(issuer: String, audience: String) -> Self {
        let allowed_algorithms = vec![jsonwebtoken::Algorithm::RS256];
        JwtConfig {
            issuer,
            audience,
            public_key_pem: None,
            allowed_algorithms,
        }
    }

    /// Attach a PEM-encoded public key, returning the updated config.
    pub fn with_public_key(self, pem: String) -> Self {
        JwtConfig {
            public_key_pem: Some(pem),
            ..self
        }
    }
}
59
/// Claims carried by a JWT that passed validation.
///
/// `sub`, `exp` and `iss` must be present when deserializing; every other
/// field falls back to its `Default` when the token omits it.
#[cfg(feature = "jwt")]
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct TokenClaims {
    /// Subject: the user identifier.
    pub sub: String,
    /// User's email address, if the provider includes it.
    #[serde(default)]
    pub email: Option<String>,
    /// Role names granted to the user (empty when absent).
    #[serde(default)]
    pub roles: Vec<String>,
    /// Tenant identifier for multi-tenant deployments, if present.
    #[serde(default)]
    pub tenant_id: Option<String>,
    /// Expiration time (`exp`) as a numeric timestamp.
    pub exp: usize,
    /// Issuer (`iss`).
    pub iss: String,
    /// Audience (`aud`), kept as raw JSON because it may be a string or an array.
    #[serde(default)]
    pub aud: Option<serde_json::Value>,
}
83
/// Checks JWTs from an external OIDC provider against a [`JwtConfig`].
#[cfg(feature = "jwt")]
#[derive(Clone)]
pub struct JwtValidator {
    /// Issuer / audience / algorithm settings applied to every token.
    config: JwtConfig,
    /// Key parsed from `config.public_key_pem`; `None` when no PEM was supplied.
    decoding_key: Option<jsonwebtoken::DecodingKey>,
}
91
#[cfg(feature = "jwt")]
impl std::fmt::Debug for JwtValidator {
    /// Formats the validator without revealing key material: the decoding key
    /// is shown only as `Some("[redacted]")` or `None`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let key_placeholder: Option<&str> = if self.decoding_key.is_some() {
            Some("[redacted]")
        } else {
            None
        };

        let mut out = f.debug_struct("JwtValidator");
        out.field("config", &self.config);
        out.field("decoding_key", &key_placeholder);
        out.finish()
    }
}
104
#[cfg(feature = "jwt")]
impl JwtValidator {
    /// Create a new JWT validator from `config`.
    ///
    /// # Errors
    ///
    /// Returns an error string if `config.public_key_pem` is set but cannot be
    /// parsed as an RSA public key in PEM format.
    pub fn new(config: JwtConfig) -> Result<Self, String> {
        let decoding_key = config
            .public_key_pem
            .as_deref()
            .map(|pem| {
                jsonwebtoken::DecodingKey::from_rsa_pem(pem.as_bytes())
                    .map_err(|e| format!("Invalid RSA PEM key: {}", e))
            })
            .transpose()?;

        Ok(Self {
            config,
            decoding_key,
        })
    }

    /// Validate a JWT token and extract its claims.
    ///
    /// The signature is checked against the configured RSA key using the
    /// RSA-family algorithms from `allowed_algorithms`. Issuer, audience and
    /// expiry are then validated.
    ///
    /// # Errors
    ///
    /// Returns an error string if no public key is configured, if
    /// `allowed_algorithms` contains no RSA algorithm, or if `jsonwebtoken`
    /// rejects the token (bad signature, wrong `iss`/`aud`, expired, ...).
    pub fn validate_token(&self, token: &str) -> Result<TokenClaims, String> {
        let decoding_key = self
            .decoding_key
            .as_ref()
            .ok_or_else(|| "No decoding key configured".to_string())?;

        // `Validation::new` seeds `algorithms` with a single entry. Replace it
        // with the whole allow-list so every configured algorithm is honoured,
        // not just the first one.
        let mut validation = jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::RS256);
        validation.algorithms = self.verification_algorithms()?;
        validation.set_issuer(&[&self.config.issuer]);
        validation.set_audience(&[&self.config.audience]);
        validation.validate_exp = true;

        jsonwebtoken::decode::<TokenClaims>(token, decoding_key, &validation)
            .map(|token_data| token_data.claims)
            .map_err(|e| format!("JWT validation failed: {}", e))
    }

    /// Algorithms handed to `jsonwebtoken` as the signature allow-list.
    ///
    /// The decoding key is always built with `DecodingKey::from_rsa_pem`, so only
    /// RSA-family algorithms (RS256/384/512, PS256/384/512) can verify with it.
    /// Other entries are dropped; this also stops an HMAC algorithm from ever
    /// being paired with the RSA key material. NOTE(review): recent
    /// `jsonwebtoken` versions reject every token when any listed algorithm
    /// does not match the key family, which is another reason to filter here.
    ///
    /// An empty configured list keeps the historical RS256 default.
    fn verification_algorithms(&self) -> Result<Vec<jsonwebtoken::Algorithm>, String> {
        use jsonwebtoken::Algorithm;

        let allowed = &self.config.allowed_algorithms;
        if allowed.is_empty() {
            return Ok(vec![Algorithm::RS256]);
        }

        let rsa_only: Vec<Algorithm> = allowed
            .iter()
            .copied()
            .filter(|alg| {
                matches!(
                    alg,
                    Algorithm::RS256
                        | Algorithm::RS384
                        | Algorithm::RS512
                        | Algorithm::PS256
                        | Algorithm::PS384
                        | Algorithm::PS512
                )
            })
            .collect();

        if rsa_only.is_empty() {
            Err(
                "No RSA algorithm in allowed_algorithms; only RSA public keys are supported"
                    .to_string(),
            )
        } else {
            Ok(rsa_only)
        }
    }
}
148
149// ===========================================================================
150// Authentication configuration
151// ===========================================================================
152
153/// Authentication configuration.
154#[derive(Clone, Debug)]
155pub struct AuthConfig {
156    /// Whether authentication is enabled.
157    pub enabled: bool,
158    /// Argon2id hashed API keys (PHC format strings).
159    hashed_keys: Vec<String>,
160    /// Paths that don't require authentication (e.g., health checks).
161    pub exempt_paths: HashSet<String>,
162    /// LRU cache for recently verified keys (fast hash -> expiry).
163    cache: Arc<Mutex<Vec<CacheEntry>>>,
164    /// JWT validator (only available with `jwt` feature).
165    #[cfg(feature = "jwt")]
166    pub jwt_validator: Option<JwtValidator>,
167}
168
/// A remembered successful API-key verification.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// Non-cryptographic hash of the key that was verified; the key itself is
    /// never kept.
    key_hash: u64,
    /// Point in time after which this entry no longer counts as a hit.
    expires_at: Instant,
}
176
177impl Default for AuthConfig {
178    fn default() -> Self {
179        Self {
180            enabled: false,
181            hashed_keys: Vec::new(),
182            exempt_paths: HashSet::from([
183                "/health".to_string(),
184                "/ready".to_string(),
185                "/live".to_string(),
186                "/metrics".to_string(),
187            ]),
188            cache: Arc::new(Mutex::new(Vec::new())),
189            #[cfg(feature = "jwt")]
190            jwt_validator: None,
191        }
192    }
193}
194
195impl AuthConfig {
196    /// Create a new auth config with API key authentication enabled.
197    ///
198    /// Keys are hashed with Argon2id at construction time.
199    pub fn with_api_keys(api_keys: Vec<String>) -> Self {
200        let argon2 = Argon2::default();
201        let hashed_keys: Vec<String> = api_keys
202            .iter()
203            .map(|key| {
204                let salt = SaltString::generate(&mut OsRng);
205                argon2
206                    .hash_password(key.as_bytes(), &salt)
207                    .expect("Argon2id hashing should not fail")
208                    .to_string()
209            })
210            .collect();
211
212        Self {
213            enabled: true,
214            hashed_keys,
215            exempt_paths: HashSet::from([
216                "/health".to_string(),
217                "/ready".to_string(),
218                "/live".to_string(),
219                "/metrics".to_string(),
220            ]),
221            cache: Arc::new(Mutex::new(Vec::new())),
222            #[cfg(feature = "jwt")]
223            jwt_validator: None,
224        }
225    }
226
227    /// Create a new auth config with pre-hashed keys (PHC format).
228    ///
229    /// Use this when keys are already hashed (e.g., loaded from config).
230    pub fn with_prehashed_keys(hashed_keys: Vec<String>) -> Self {
231        Self {
232            enabled: true,
233            hashed_keys,
234            exempt_paths: HashSet::from([
235                "/health".to_string(),
236                "/ready".to_string(),
237                "/live".to_string(),
238                "/metrics".to_string(),
239            ]),
240            cache: Arc::new(Mutex::new(Vec::new())),
241            #[cfg(feature = "jwt")]
242            jwt_validator: None,
243        }
244    }
245
246    /// Add JWT validation support.
247    #[cfg(feature = "jwt")]
248    pub fn with_jwt(mut self, config: JwtConfig) -> Result<Self, String> {
249        let validator = JwtValidator::new(config)?;
250        self.jwt_validator = Some(validator);
251        self.enabled = true;
252        Ok(self)
253    }
254
255    /// Add exempt paths that don't require authentication.
256    pub fn with_exempt_paths(mut self, paths: Vec<String>) -> Self {
257        for path in paths {
258            self.exempt_paths.insert(path);
259        }
260        self
261    }
262
263    /// Verify an API key against all stored hashes.
264    ///
265    /// Iterates ALL hashes to prevent timing side-channels on which
266    /// key matched or how many keys exist.
267    fn verify_key(&self, submitted_key: &str) -> bool {
268        let key_hash = fast_hash(submitted_key);
269
270        // Check cache first
271        {
272            let cache = self.cache.lock().unwrap();
273            let now = Instant::now();
274            for entry in cache.iter() {
275                if entry.key_hash == key_hash && entry.expires_at > now {
276                    return true;
277                }
278            }
279        }
280
281        // Verify against all hashed keys (no short-circuit)
282        let argon2 = Argon2::default();
283        let mut any_match = false;
284
285        for stored_hash in &self.hashed_keys {
286            if let Ok(parsed_hash) = PasswordHash::new(stored_hash) {
287                if argon2
288                    .verify_password(submitted_key.as_bytes(), &parsed_hash)
289                    .is_ok()
290                {
291                    any_match = true;
292                }
293            }
294        }
295
296        // Cache on success
297        if any_match {
298            let mut cache = self.cache.lock().unwrap();
299            // Evict expired entries
300            let now = Instant::now();
301            cache.retain(|e| e.expires_at > now);
302            // Add new entry with 5s TTL
303            cache.push(CacheEntry {
304                key_hash,
305                expires_at: now + Duration::from_secs(5),
306            });
307        }
308
309        any_match
310    }
311
312    /// Try to validate a Bearer token as JWT first, then fall back to API key.
313    fn verify_bearer(&self, token: &str) -> AuthResult {
314        // Try JWT first (if feature enabled and configured)
315        #[cfg(feature = "jwt")]
316        if let Some(ref validator) = self.jwt_validator {
317            match validator.validate_token(token) {
318                Ok(_claims) => return AuthResult::Authenticated,
319                Err(_) => {
320                    // JWT validation failed — fall through to API key check
321                }
322            }
323        }
324
325        // Fall back to API key verification
326        if self.verify_key(token) {
327            AuthResult::Authenticated
328        } else {
329            AuthResult::InvalidCredentials
330        }
331    }
332}
333
/// Result of an authentication attempt.
///
/// Marked `#[must_use]`: silently discarding an authentication outcome is
/// always a bug.
#[must_use]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AuthResult {
    /// The presented credential was accepted.
    Authenticated,
    /// The presented credential was rejected.
    InvalidCredentials,
}
339
/// 64-bit FNV-1a hash of `s`, used only as a cache lookup key.
///
/// Not cryptographic: it is cheap to compute and not collision-resistant.
fn fast_hash(s: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    s.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
350
351/// Authentication middleware that checks for valid API key or JWT.
352///
353/// Checks for credentials in:
354/// 1. `Authorization: Bearer <key_or_jwt>` header
355/// 2. `X-API-Key: <key>` header
356pub async fn auth_middleware(
357    axum::Extension(config): axum::Extension<AuthConfig>,
358    request: Request<Body>,
359    next: Next,
360) -> Response {
361    // Skip if auth is disabled
362    if !config.enabled {
363        return next.run(request).await;
364    }
365
366    // Check if path is exempt
367    let path = request.uri().path();
368    if config.exempt_paths.contains(path) {
369        return next.run(request).await;
370    }
371
372    // Extract credential from headers
373    let bearer_token = extract_bearer_token(&request);
374    let api_key = extract_x_api_key(&request);
375
376    // Try Bearer token first (supports both JWT and API key)
377    if let Some(token) = bearer_token {
378        return match config.verify_bearer(&token) {
379            AuthResult::Authenticated => next.run(request).await,
380            AuthResult::InvalidCredentials => (
381                StatusCode::UNAUTHORIZED,
382                [(header::WWW_AUTHENTICATE, "Bearer")],
383                "Invalid credentials",
384            )
385                .into_response(),
386        };
387    }
388
389    // Try X-API-Key header
390    if let Some(key) = api_key {
391        if config.verify_key(&key) {
392            return next.run(request).await;
393        }
394        return (
395            StatusCode::UNAUTHORIZED,
396            [(header::WWW_AUTHENTICATE, "Bearer")],
397            "Invalid API key",
398        )
399            .into_response();
400    }
401
402    // No credentials provided
403    (
404        StatusCode::UNAUTHORIZED,
405        [(header::WWW_AUTHENTICATE, "Bearer")],
406        "API key required. Provide via 'Authorization: Bearer <key>' or 'X-API-Key' header",
407    )
408        .into_response()
409}
410
411/// Extract Bearer token from Authorization header.
412fn extract_bearer_token(request: &Request<Body>) -> Option<String> {
413    request
414        .headers()
415        .get(header::AUTHORIZATION)
416        .and_then(|h| h.to_str().ok())
417        .and_then(|s| s.strip_prefix("Bearer "))
418        .map(|s| s.to_string())
419}
420
421/// Extract API key from X-API-Key header.
422fn extract_x_api_key(request: &Request<Body>) -> Option<String> {
423    request
424        .headers()
425        .get("X-API-Key")
426        .and_then(|h| h.to_str().ok())
427        .map(|s| s.to_string())
428}
429
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        middleware,
        routing::get,
        Router,
    };
    use tower::ServiceExt;

    async fn test_handler() -> &'static str {
        "ok"
    }

    fn test_router(config: AuthConfig) -> Router {
        Router::new()
            .route("/api/test", get(test_handler))
            .route("/health", get(test_handler))
            .layer(middleware::from_fn(auth_middleware))
            .layer(axum::Extension(config))
    }

    /// Send a GET for `uri` with `headers` through a router built from
    /// `config` and return the response status.
    async fn status_for(config: AuthConfig, uri: &str, headers: &[(&str, &str)]) -> StatusCode {
        let mut builder = Request::builder().uri(uri);
        for &(name, value) in headers {
            builder = builder.header(name, value);
        }
        let request = builder.body(Body::empty()).unwrap();
        test_router(config).oneshot(request).await.unwrap().status()
    }

    #[tokio::test]
    async fn test_auth_disabled() {
        let status = status_for(AuthConfig::default(), "/api/test", &[]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_valid_bearer_token() {
        let config = AuthConfig::with_api_keys(vec!["test-key-123".to_string()]);
        let status = status_for(
            config,
            "/api/test",
            &[("Authorization", "Bearer test-key-123")],
        )
        .await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_valid_x_api_key() {
        let config = AuthConfig::with_api_keys(vec!["test-key-456".to_string()]);
        let status = status_for(config, "/api/test", &[("X-API-Key", "test-key-456")]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_invalid_api_key() {
        let config = AuthConfig::with_api_keys(vec!["valid-key".to_string()]);
        let status = status_for(config, "/api/test", &[("Authorization", "Bearer wrong-key")]).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_invalid_x_api_key() {
        let config = AuthConfig::with_api_keys(vec!["valid-key".to_string()]);
        let status = status_for(config, "/api/test", &[("X-API-Key", "wrong-key")]).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_invalid_bearer_does_not_fall_back_to_x_api_key() {
        // A present Bearer token is authoritative; a valid X-API-Key next to an
        // invalid Bearer token must not rescue the request.
        let config = AuthConfig::with_api_keys(vec!["valid-key".to_string()]);
        let status = status_for(
            config,
            "/api/test",
            &[
                ("Authorization", "Bearer wrong-key"),
                ("X-API-Key", "valid-key"),
            ],
        )
        .await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_missing_api_key() {
        let config = AuthConfig::with_api_keys(vec!["valid-key".to_string()]);
        let status = status_for(config, "/api/test", &[]).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_exempt_path() {
        let config = AuthConfig::with_api_keys(vec!["valid-key".to_string()]);
        let status = status_for(config, "/health", &[]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_custom_exempt_path() {
        let config = AuthConfig::with_api_keys(vec!["valid-key".to_string()])
            .with_exempt_paths(vec!["/api/test".to_string()]);
        let status = status_for(config, "/api/test", &[]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_prehashed_keys() {
        // Hash a key manually
        let argon2 = Argon2::default();
        let salt = SaltString::generate(&mut OsRng);
        let hash = argon2
            .hash_password(b"pre-hashed-key", &salt)
            .unwrap()
            .to_string();

        let config = AuthConfig::with_prehashed_keys(vec![hash]);
        let status = status_for(
            config,
            "/api/test",
            &[("Authorization", "Bearer pre-hashed-key")],
        )
        .await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_malformed_prehashed_key_is_rejected() {
        // A stored string that is not a PHC hash must never authenticate,
        // not even when the caller submits that exact string.
        let config = AuthConfig::with_prehashed_keys(vec!["not-a-phc-hash".to_string()]);
        let status = status_for(
            config,
            "/api/test",
            &[("Authorization", "Bearer not-a-phc-hash")],
        )
        .await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_cache_hit() {
        let config = AuthConfig::with_api_keys(vec!["cached-key".to_string()]);
        let headers = [("Authorization", "Bearer cached-key")];

        // First request misses the cache, runs Argon2, and records one entry in
        // the cache shared by every clone of `config`.
        let first = status_for(config.clone(), "/api/test", &headers).await;
        assert_eq!(first, StatusCode::OK);
        assert_eq!(config.cache.lock().unwrap().len(), 1);

        // Second request hits that entry, so no new entry is pushed.
        let second = status_for(config.clone(), "/api/test", &headers).await;
        assert_eq!(second, StatusCode::OK);
        assert_eq!(config.cache.lock().unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_api_key_fallback_still_works() {
        // Even without JWT feature, API key auth should work
        let config = AuthConfig::with_api_keys(vec!["my-key".to_string()]);
        let status = status_for(config, "/api/test", &[("Authorization", "Bearer my-key")]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[test]
    fn test_fast_hash_matches_fnv1a_vectors() {
        assert_eq!(fast_hash(""), 0xcbf2_9ce4_8422_2325);
        assert_eq!(fast_hash("a"), 0xaf63_dc4c_8601_ec8c);
    }

    #[cfg(feature = "jwt")]
    mod jwt_tests {
        use super::*;

        #[test]
        fn test_jwt_config_creation() {
            let config =
                JwtConfig::new("https://auth.example.com".to_string(), "my-api".to_string());
            assert_eq!(config.issuer, "https://auth.example.com");
            assert_eq!(config.audience, "my-api");
            assert!(config.public_key_pem.is_none());
            assert_eq!(
                config.allowed_algorithms,
                vec![jsonwebtoken::Algorithm::RS256]
            );
        }

        #[test]
        fn test_jwt_validator_requires_key() {
            let config = JwtConfig::new("issuer".to_string(), "audience".to_string());
            let validator = JwtValidator::new(config).expect("should create");
            let result = validator.validate_token("some.invalid.token");
            assert!(result.is_err());
        }

        #[test]
        fn test_token_claims_deserialization() {
            let json = r#"{
                "sub": "user123",
                "email": "user@example.com",
                "roles": ["admin", "operator"],
                "tenant_id": "tenant1",
                "exp": 9999999999,
                "iss": "https://auth.example.com"
            }"#;
            let claims: TokenClaims = serde_json::from_str(json).unwrap();
            assert_eq!(claims.sub, "user123");
            assert_eq!(claims.email, Some("user@example.com".to_string()));
            assert_eq!(claims.roles, vec!["admin", "operator"]);
            assert_eq!(claims.tenant_id, Some("tenant1".to_string()));
        }
    }
}