Skip to main content

datasynth_server/rest/
auth.rs

1//! Authentication middleware for REST API.
2//!
3//! Provides API key authentication with Argon2id hashing and
4//! timing-safe comparison for protecting endpoints.
5//!
6//! When the `jwt` feature is enabled, also supports JWT validation
7//! from external OIDC providers (Keycloak, Auth0, Entra ID).
8
9use argon2::{
10    password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
11    Argon2,
12};
13use axum::{
14    body::Body,
15    http::{header, Request, StatusCode},
16    middleware::Next,
17    response::{IntoResponse, Response},
18};
19use std::collections::HashSet;
20use std::sync::{Arc, Mutex};
21use std::time::{Duration, Instant};
22
23// ===========================================================================
24// JWT types (feature-gated)
25// ===========================================================================
26
27/// JWT validation configuration for OIDC providers.
28#[cfg(feature = "jwt")]
29#[derive(Clone, Debug)]
30pub struct JwtConfig {
31    /// Expected token issuer (e.g., "https://auth.example.com/realms/main").
32    pub issuer: String,
33    /// Expected audience claim.
34    pub audience: String,
35    /// PEM-encoded public key for RS256 verification.
36    pub public_key_pem: Option<String>,
37    /// Allowed algorithms (default: RS256).
38    pub allowed_algorithms: Vec<jsonwebtoken::Algorithm>,
39}
40
41#[cfg(feature = "jwt")]
42impl JwtConfig {
43    /// Create a new JWT config with RS256 algorithm.
44    pub fn new(issuer: String, audience: String) -> Self {
45        Self {
46            issuer,
47            audience,
48            public_key_pem: None,
49            allowed_algorithms: vec![jsonwebtoken::Algorithm::RS256],
50        }
51    }
52
53    /// Set the PEM public key.
54    pub fn with_public_key(mut self, pem: String) -> Self {
55        self.public_key_pem = Some(pem);
56        self
57    }
58}
59
60/// Claims extracted from a validated JWT.
61#[cfg(feature = "jwt")]
62#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
63pub struct TokenClaims {
64    /// Subject (user ID).
65    pub sub: String,
66    /// Email address (optional).
67    #[serde(default)]
68    pub email: Option<String>,
69    /// Roles assigned to the user.
70    #[serde(default)]
71    pub roles: Vec<String>,
72    /// Tenant ID for multi-tenancy (optional).
73    #[serde(default)]
74    pub tenant_id: Option<String>,
75    /// Expiration timestamp.
76    pub exp: usize,
77    /// Issuer.
78    pub iss: String,
79    /// Audience (can be string or array).
80    #[serde(default)]
81    pub aud: Option<serde_json::Value>,
82}
83
84/// JWT validator that verifies tokens from external OIDC providers.
85#[cfg(feature = "jwt")]
86#[derive(Clone)]
87pub struct JwtValidator {
88    config: JwtConfig,
89    decoding_key: Option<jsonwebtoken::DecodingKey>,
90}
91
92#[cfg(feature = "jwt")]
93impl std::fmt::Debug for JwtValidator {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        f.debug_struct("JwtValidator")
96            .field("config", &self.config)
97            .field(
98                "decoding_key",
99                &self.decoding_key.as_ref().map(|_| "[redacted]"),
100            )
101            .finish()
102    }
103}
104
105#[cfg(feature = "jwt")]
106impl JwtValidator {
107    /// Create a new JWT validator.
108    pub fn new(config: JwtConfig) -> Result<Self, String> {
109        let decoding_key = if let Some(ref pem) = config.public_key_pem {
110            Some(
111                jsonwebtoken::DecodingKey::from_rsa_pem(pem.as_bytes())
112                    .map_err(|e| format!("Invalid RSA PEM key: {}", e))?,
113            )
114        } else {
115            None
116        };
117
118        Ok(Self {
119            config,
120            decoding_key,
121        })
122    }
123
124    /// Validate a JWT token and extract claims.
125    pub fn validate_token(&self, token: &str) -> Result<TokenClaims, String> {
126        let decoding_key = self
127            .decoding_key
128            .as_ref()
129            .ok_or_else(|| "No decoding key configured".to_string())?;
130
131        let mut validation = jsonwebtoken::Validation::new(
132            *self
133                .config
134                .allowed_algorithms
135                .first()
136                .unwrap_or(&jsonwebtoken::Algorithm::RS256),
137        );
138        validation.set_issuer(&[&self.config.issuer]);
139        validation.set_audience(&[&self.config.audience]);
140        validation.validate_exp = true;
141
142        let token_data = jsonwebtoken::decode::<TokenClaims>(token, decoding_key, &validation)
143            .map_err(|e| format!("JWT validation failed: {}", e))?;
144
145        Ok(token_data.claims)
146    }
147}
148
149// ===========================================================================
150// Authentication configuration
151// ===========================================================================
152
153/// Authentication configuration.
154#[derive(Clone, Debug)]
155pub struct AuthConfig {
156    /// Whether authentication is enabled.
157    pub enabled: bool,
158    /// Argon2id hashed API keys (PHC format strings).
159    hashed_keys: Vec<String>,
160    /// Paths that don't require authentication (e.g., health checks).
161    pub exempt_paths: HashSet<String>,
162    /// LRU cache for recently verified keys (fast hash -> expiry).
163    cache: Arc<Mutex<Vec<CacheEntry>>>,
164    /// JWT validator (only available with `jwt` feature).
165    #[cfg(feature = "jwt")]
166    pub jwt_validator: Option<JwtValidator>,
167}
168
169#[derive(Clone, Debug)]
170struct CacheEntry {
171    /// Fast hash of the submitted key (not the key itself).
172    key_hash: u64,
173    /// When this cache entry expires.
174    expires_at: Instant,
175}
176
177impl Default for AuthConfig {
178    fn default() -> Self {
179        Self {
180            enabled: false,
181            hashed_keys: Vec::new(),
182            exempt_paths: HashSet::from([
183                "/health".to_string(),
184                "/ready".to_string(),
185                "/live".to_string(),
186                "/metrics".to_string(),
187            ]),
188            cache: Arc::new(Mutex::new(Vec::new())),
189            #[cfg(feature = "jwt")]
190            jwt_validator: None,
191        }
192    }
193}
194
195impl AuthConfig {
196    /// Create a new auth config with API key authentication enabled.
197    ///
198    /// Keys are hashed with Argon2id at construction time.
199    pub fn with_api_keys(api_keys: Vec<String>) -> Self {
200        let argon2 = Argon2::default();
201        let hashed_keys: Vec<String> = api_keys
202            .iter()
203            .map(|key| {
204                let salt = SaltString::generate(&mut OsRng);
205                argon2
206                    .hash_password(key.as_bytes(), &salt)
207                    .expect("Argon2id hashing should not fail")
208                    .to_string()
209            })
210            .collect();
211
212        Self {
213            enabled: true,
214            hashed_keys,
215            exempt_paths: HashSet::from([
216                "/health".to_string(),
217                "/ready".to_string(),
218                "/live".to_string(),
219                "/metrics".to_string(),
220            ]),
221            cache: Arc::new(Mutex::new(Vec::new())),
222            #[cfg(feature = "jwt")]
223            jwt_validator: None,
224        }
225    }
226
227    /// Create a new auth config with pre-hashed keys (PHC format).
228    ///
229    /// Use this when keys are already hashed (e.g., loaded from config).
230    pub fn with_prehashed_keys(hashed_keys: Vec<String>) -> Self {
231        Self {
232            enabled: true,
233            hashed_keys,
234            exempt_paths: HashSet::from([
235                "/health".to_string(),
236                "/ready".to_string(),
237                "/live".to_string(),
238                "/metrics".to_string(),
239            ]),
240            cache: Arc::new(Mutex::new(Vec::new())),
241            #[cfg(feature = "jwt")]
242            jwt_validator: None,
243        }
244    }
245
246    /// Add JWT validation support.
247    #[cfg(feature = "jwt")]
248    pub fn with_jwt(mut self, config: JwtConfig) -> Result<Self, String> {
249        let validator = JwtValidator::new(config)?;
250        self.jwt_validator = Some(validator);
251        self.enabled = true;
252        Ok(self)
253    }
254
255    /// Add exempt paths that don't require authentication.
256    pub fn with_exempt_paths(mut self, paths: Vec<String>) -> Self {
257        for path in paths {
258            self.exempt_paths.insert(path);
259        }
260        self
261    }
262
263    /// Verify an API key against all stored hashes.
264    ///
265    /// Iterates ALL hashes to prevent timing side-channels on which
266    /// key matched or how many keys exist.
267    fn verify_key(&self, submitted_key: &str) -> bool {
268        let key_hash = fast_hash(submitted_key);
269
270        // Check cache first
271        {
272            let cache = self.cache.lock().unwrap_or_else(|e| e.into_inner());
273            let now = Instant::now();
274            for entry in cache.iter() {
275                if entry.key_hash == key_hash && entry.expires_at > now {
276                    return true;
277                }
278            }
279        }
280
281        // Verify against all hashed keys (no short-circuit)
282        let argon2 = Argon2::default();
283        let mut any_match = false;
284
285        for stored_hash in &self.hashed_keys {
286            if let Ok(parsed_hash) = PasswordHash::new(stored_hash) {
287                if argon2
288                    .verify_password(submitted_key.as_bytes(), &parsed_hash)
289                    .is_ok()
290                {
291                    any_match = true;
292                }
293            }
294        }
295
296        // Cache on success
297        if any_match {
298            let mut cache = self.cache.lock().unwrap_or_else(|e| e.into_inner());
299            // Evict expired entries
300            let now = Instant::now();
301            cache.retain(|e| e.expires_at > now);
302            // Add new entry with 5s TTL
303            cache.push(CacheEntry {
304                key_hash,
305                expires_at: now + Duration::from_secs(5),
306            });
307        }
308
309        any_match
310    }
311
312    /// Try to validate a Bearer token as JWT first, then fall back to API key.
313    fn verify_bearer(&self, token: &str) -> AuthResult {
314        // Try JWT first (if feature enabled and configured)
315        #[cfg(feature = "jwt")]
316        if let Some(ref validator) = self.jwt_validator {
317            match validator.validate_token(token) {
318                Ok(_claims) => return AuthResult::Authenticated,
319                Err(_) => {
320                    // JWT validation failed — fall through to API key check
321                }
322            }
323        }
324
325        // Fall back to API key verification
326        if self.verify_key(token) {
327            AuthResult::Authenticated
328        } else {
329            AuthResult::InvalidCredentials
330        }
331    }
332}
333
334/// Result of an authentication attempt.
335enum AuthResult {
336    Authenticated,
337    InvalidCredentials,
338}
339
340/// Fast non-cryptographic hash for cache key lookup.
341fn fast_hash(s: &str) -> u64 {
342    // FNV-1a hash
343    let mut hash: u64 = 0xcbf29ce484222325;
344    for byte in s.bytes() {
345        hash ^= byte as u64;
346        hash = hash.wrapping_mul(0x100000001b3);
347    }
348    hash
349}
350
351/// Authentication middleware that checks for valid API key or JWT.
352///
353/// Checks for credentials in:
354/// 1. `Authorization: Bearer <key_or_jwt>` header
355/// 2. `X-API-Key: <key>` header
356pub async fn auth_middleware(
357    axum::Extension(config): axum::Extension<AuthConfig>,
358    request: Request<Body>,
359    next: Next,
360) -> Response {
361    // Skip if auth is disabled
362    if !config.enabled {
363        return next.run(request).await;
364    }
365
366    // Check if path is exempt
367    let path = request.uri().path();
368    if config.exempt_paths.contains(path) {
369        return next.run(request).await;
370    }
371
372    // Extract credential from headers
373    let bearer_token = extract_bearer_token(&request);
374    let api_key = extract_x_api_key(&request);
375
376    // Try Bearer token first (supports both JWT and API key)
377    if let Some(token) = bearer_token {
378        return match config.verify_bearer(&token) {
379            AuthResult::Authenticated => next.run(request).await,
380            AuthResult::InvalidCredentials => (
381                StatusCode::UNAUTHORIZED,
382                [(header::WWW_AUTHENTICATE, "Bearer")],
383                "Invalid credentials",
384            )
385                .into_response(),
386        };
387    }
388
389    // Try X-API-Key header
390    if let Some(key) = api_key {
391        if config.verify_key(&key) {
392            return next.run(request).await;
393        }
394        return (
395            StatusCode::UNAUTHORIZED,
396            [(header::WWW_AUTHENTICATE, "Bearer")],
397            "Invalid API key",
398        )
399            .into_response();
400    }
401
402    // No credentials provided
403    (
404        StatusCode::UNAUTHORIZED,
405        [(header::WWW_AUTHENTICATE, "Bearer")],
406        "API key required. Provide via 'Authorization: Bearer <key>' or 'X-API-Key' header",
407    )
408        .into_response()
409}
410
411/// Extract Bearer token from Authorization header.
412fn extract_bearer_token(request: &Request<Body>) -> Option<String> {
413    request
414        .headers()
415        .get(header::AUTHORIZATION)
416        .and_then(|h| h.to_str().ok())
417        .and_then(|s| s.strip_prefix("Bearer "))
418        .map(std::string::ToString::to_string)
419}
420
421/// Extract API key from X-API-Key header.
422fn extract_x_api_key(request: &Request<Body>) -> Option<String> {
423    request
424        .headers()
425        .get("X-API-Key")
426        .and_then(|h| h.to_str().ok())
427        .map(std::string::ToString::to_string)
428}
429
430#[cfg(test)]
431mod tests {
432    use super::*;
433    use axum::{
434        body::Body,
435        http::{Request, StatusCode},
436        middleware,
437        routing::get,
438        Router,
439    };
440    use tower::ServiceExt;
441
442    async fn test_handler() -> &'static str {
443        "ok"
444    }
445
446    fn test_router(config: AuthConfig) -> Router {
447        Router::new()
448            .route("/api/test", get(test_handler))
449            .route("/health", get(test_handler))
450            .layer(middleware::from_fn(auth_middleware))
451            .layer(axum::Extension(config))
452    }
453
454    #[tokio::test]
455    async fn test_auth_disabled() {
456        let config = AuthConfig::default();
457        let router = test_router(config);
458
459        let request = Request::builder()
460            .uri("/api/test")
461            .body(Body::empty())
462            .unwrap();
463
464        let response = router.oneshot(request).await.unwrap();
465        assert_eq!(response.status(), StatusCode::OK);
466    }
467
468    #[tokio::test]
469    async fn test_valid_bearer_token() {
470        let config = AuthConfig::with_api_keys(vec!["test-key-123".to_string()]);
471        let router = test_router(config);
472
473        let request = Request::builder()
474            .uri("/api/test")
475            .header("Authorization", "Bearer test-key-123")
476            .body(Body::empty())
477            .unwrap();
478
479        let response = router.oneshot(request).await.unwrap();
480        assert_eq!(response.status(), StatusCode::OK);
481    }
482
483    #[tokio::test]
484    async fn test_valid_x_api_key() {
485        let config = AuthConfig::with_api_keys(vec!["test-key-456".to_string()]);
486        let router = test_router(config);
487
488        let request = Request::builder()
489            .uri("/api/test")
490            .header("X-API-Key", "test-key-456")
491            .body(Body::empty())
492            .unwrap();
493
494        let response = router.oneshot(request).await.unwrap();
495        assert_eq!(response.status(), StatusCode::OK);
496    }
497
498    #[tokio::test]
499    async fn test_invalid_api_key() {
500        let config = AuthConfig::with_api_keys(vec!["valid-key".to_string()]);
501        let router = test_router(config);
502
503        let request = Request::builder()
504            .uri("/api/test")
505            .header("Authorization", "Bearer wrong-key")
506            .body(Body::empty())
507            .unwrap();
508
509        let response = router.oneshot(request).await.unwrap();
510        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
511    }
512
513    #[tokio::test]
514    async fn test_missing_api_key() {
515        let config = AuthConfig::with_api_keys(vec!["valid-key".to_string()]);
516        let router = test_router(config);
517
518        let request = Request::builder()
519            .uri("/api/test")
520            .body(Body::empty())
521            .unwrap();
522
523        let response = router.oneshot(request).await.unwrap();
524        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
525    }
526
527    #[tokio::test]
528    async fn test_exempt_path() {
529        let config = AuthConfig::with_api_keys(vec!["valid-key".to_string()]);
530        let router = test_router(config);
531
532        let request = Request::builder()
533            .uri("/health")
534            .body(Body::empty())
535            .unwrap();
536
537        let response = router.oneshot(request).await.unwrap();
538        assert_eq!(response.status(), StatusCode::OK);
539    }
540
541    #[tokio::test]
542    async fn test_prehashed_keys() {
543        // Hash a key manually
544        let argon2 = Argon2::default();
545        let salt = SaltString::generate(&mut OsRng);
546        let hash = argon2
547            .hash_password(b"pre-hashed-key", &salt)
548            .unwrap()
549            .to_string();
550
551        let config = AuthConfig::with_prehashed_keys(vec![hash]);
552        let router = test_router(config);
553
554        let request = Request::builder()
555            .uri("/api/test")
556            .header("Authorization", "Bearer pre-hashed-key")
557            .body(Body::empty())
558            .unwrap();
559
560        let response = router.oneshot(request).await.unwrap();
561        assert_eq!(response.status(), StatusCode::OK);
562    }
563
564    #[tokio::test]
565    async fn test_cache_hit() {
566        let config = AuthConfig::with_api_keys(vec!["cached-key".to_string()]);
567
568        // First request - populates cache
569        let router1 = test_router(config.clone());
570        let request1 = Request::builder()
571            .uri("/api/test")
572            .header("Authorization", "Bearer cached-key")
573            .body(Body::empty())
574            .unwrap();
575        let response1 = router1.oneshot(request1).await.unwrap();
576        assert_eq!(response1.status(), StatusCode::OK);
577
578        // Second request - should hit cache
579        let router2 = test_router(config);
580        let request2 = Request::builder()
581            .uri("/api/test")
582            .header("Authorization", "Bearer cached-key")
583            .body(Body::empty())
584            .unwrap();
585        let response2 = router2.oneshot(request2).await.unwrap();
586        assert_eq!(response2.status(), StatusCode::OK);
587    }
588
589    #[tokio::test]
590    async fn test_api_key_fallback_still_works() {
591        // Even without JWT feature, API key auth should work
592        let config = AuthConfig::with_api_keys(vec!["my-key".to_string()]);
593        let router = test_router(config);
594
595        let request = Request::builder()
596            .uri("/api/test")
597            .header("Authorization", "Bearer my-key")
598            .body(Body::empty())
599            .unwrap();
600
601        let response = router.oneshot(request).await.unwrap();
602        assert_eq!(response.status(), StatusCode::OK);
603    }
604
605    #[cfg(feature = "jwt")]
606    mod jwt_tests {
607        use super::*;
608
609        #[test]
610        fn test_jwt_config_creation() {
611            let config =
612                JwtConfig::new("https://auth.example.com".to_string(), "my-api".to_string());
613            assert_eq!(config.issuer, "https://auth.example.com");
614            assert_eq!(config.audience, "my-api");
615            assert!(config.public_key_pem.is_none());
616            assert_eq!(
617                config.allowed_algorithms,
618                vec![jsonwebtoken::Algorithm::RS256]
619            );
620        }
621
622        #[test]
623        fn test_jwt_validator_requires_key() {
624            let config = JwtConfig::new("issuer".to_string(), "audience".to_string());
625            let validator = JwtValidator::new(config).expect("should create");
626            let result = validator.validate_token("some.invalid.token");
627            assert!(result.is_err());
628        }
629
630        #[test]
631        fn test_token_claims_deserialization() {
632            let json = r#"{
633                "sub": "user123",
634                "email": "user@example.com",
635                "roles": ["admin", "operator"],
636                "tenant_id": "tenant1",
637                "exp": 9999999999,
638                "iss": "https://auth.example.com"
639            }"#;
640            let claims: TokenClaims = serde_json::from_str(json).unwrap();
641            assert_eq!(claims.sub, "user123");
642            assert_eq!(claims.email, Some("user@example.com".to_string()));
643            assert_eq!(claims.roles, vec!["admin", "operator"]);
644            assert_eq!(claims.tenant_id, Some("tenant1".to_string()));
645        }
646    }
647}