// forge_runtime/gateway/auth.rs

1use std::sync::Arc;
2
3use axum::{
4    body::Body,
5    extract::{Request, State},
6    middleware::Next,
7    response::Response,
8};
9use forge_core::auth::Claims;
10use forge_core::config::JwtAlgorithm as CoreJwtAlgorithm;
11use forge_core::function::AuthContext;
12use jsonwebtoken::{Algorithm, DecodingKey, Validation, dangerous, decode};
13use tracing::debug;
14
15use super::jwks::JwksClient;
16
17/// Authentication configuration for the runtime.
18#[derive(Debug, Clone)]
19pub struct AuthConfig {
20    /// JWT secret for HMAC algorithms (HS256, HS384, HS512).
21    pub jwt_secret: Option<String>,
22    /// JWT algorithm.
23    pub algorithm: JwtAlgorithm,
24    /// JWKS client for RSA algorithms.
25    pub jwks_client: Option<Arc<JwksClient>>,
26    /// Expected token issuer (iss claim).
27    pub issuer: Option<String>,
28    /// Expected audience (aud claim).
29    pub audience: Option<String>,
30    /// Whether to allow unauthenticated requests.
31    pub allow_anonymous: bool,
32    /// Skip signature verification (DEV MODE ONLY - NEVER USE IN PRODUCTION).
33    pub skip_verification: bool,
34}
35
36impl Default for AuthConfig {
37    fn default() -> Self {
38        Self {
39            jwt_secret: None,
40            algorithm: JwtAlgorithm::HS256,
41            jwks_client: None,
42            issuer: None,
43            audience: None,
44            allow_anonymous: true,
45            skip_verification: false,
46        }
47    }
48}
49
50impl AuthConfig {
51    /// Create auth config from forge core config.
52    pub fn from_forge_config(config: &forge_core::config::AuthConfig) -> Self {
53        let algorithm = JwtAlgorithm::from(config.algorithm);
54
55        let jwks_client = config
56            .jwks_url
57            .as_ref()
58            .map(|url| Arc::new(JwksClient::new(url.clone(), config.jwks_cache_ttl_secs)));
59
60        Self {
61            jwt_secret: config.jwt_secret.clone(),
62            algorithm,
63            jwks_client,
64            issuer: config.issuer.clone(),
65            audience: config.audience.clone(),
66            allow_anonymous: config.allow_anonymous,
67            skip_verification: false,
68        }
69    }
70
71    /// Create a new auth config with the given HMAC secret.
72    pub fn with_secret(secret: impl Into<String>) -> Self {
73        Self {
74            jwt_secret: Some(secret.into()),
75            ..Default::default()
76        }
77    }
78
79    /// Create a dev mode config that skips signature verification.
80    /// WARNING: Only use this for development and testing!
81    pub fn dev_mode() -> Self {
82        Self {
83            jwt_secret: None,
84            algorithm: JwtAlgorithm::HS256,
85            jwks_client: None,
86            issuer: None,
87            audience: None,
88            allow_anonymous: true,
89            skip_verification: true,
90        }
91    }
92
93    /// Check if this config uses HMAC (symmetric) algorithms.
94    pub fn is_hmac(&self) -> bool {
95        matches!(
96            self.algorithm,
97            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
98        )
99    }
100
101    /// Check if this config uses RSA (asymmetric) algorithms.
102    pub fn is_rsa(&self) -> bool {
103        matches!(
104            self.algorithm,
105            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
106        )
107    }
108}
109
/// JWT signing algorithms the runtime can verify.
///
/// `HS*` variants are verified with a shared secret; `RS*` variants with
/// public keys obtained from a JWKS endpoint. Defaults to [`JwtAlgorithm::HS256`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum JwtAlgorithm {
    /// HMAC with SHA-256 (the default).
    #[default]
    HS256,
    /// HMAC with SHA-384.
    HS384,
    /// HMAC with SHA-512.
    HS512,
    /// RSA signature with SHA-256.
    RS256,
    /// RSA signature with SHA-384.
    RS384,
    /// RSA signature with SHA-512.
    RS512,
}
121
122impl From<JwtAlgorithm> for Algorithm {
123    fn from(alg: JwtAlgorithm) -> Self {
124        match alg {
125            JwtAlgorithm::HS256 => Algorithm::HS256,
126            JwtAlgorithm::HS384 => Algorithm::HS384,
127            JwtAlgorithm::HS512 => Algorithm::HS512,
128            JwtAlgorithm::RS256 => Algorithm::RS256,
129            JwtAlgorithm::RS384 => Algorithm::RS384,
130            JwtAlgorithm::RS512 => Algorithm::RS512,
131        }
132    }
133}
134
135impl From<CoreJwtAlgorithm> for JwtAlgorithm {
136    fn from(alg: CoreJwtAlgorithm) -> Self {
137        match alg {
138            CoreJwtAlgorithm::HS256 => JwtAlgorithm::HS256,
139            CoreJwtAlgorithm::HS384 => JwtAlgorithm::HS384,
140            CoreJwtAlgorithm::HS512 => JwtAlgorithm::HS512,
141            CoreJwtAlgorithm::RS256 => JwtAlgorithm::RS256,
142            CoreJwtAlgorithm::RS384 => JwtAlgorithm::RS384,
143            CoreJwtAlgorithm::RS512 => JwtAlgorithm::RS512,
144        }
145    }
146}
147
148/// Authentication middleware.
149#[derive(Clone)]
150pub struct AuthMiddleware {
151    config: Arc<AuthConfig>,
152    /// Pre-computed HMAC decoding key (for performance).
153    hmac_key: Option<DecodingKey>,
154}
155
156impl std::fmt::Debug for AuthMiddleware {
157    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158        f.debug_struct("AuthMiddleware")
159            .field("config", &self.config)
160            .field("hmac_key", &self.hmac_key.is_some())
161            .finish()
162    }
163}
164
165impl AuthMiddleware {
166    /// Create a new auth middleware.
167    pub fn new(config: AuthConfig) -> Self {
168        // Pre-compute HMAC key if using HMAC algorithm
169        let hmac_key = if config.skip_verification {
170            None
171        } else if config.is_hmac() {
172            config
173                .jwt_secret
174                .as_ref()
175                .filter(|s| !s.is_empty())
176                .map(|secret| DecodingKey::from_secret(secret.as_bytes()))
177        } else {
178            None
179        };
180
181        Self {
182            config: Arc::new(config),
183            hmac_key,
184        }
185    }
186
187    /// Create a middleware that allows all requests (development mode).
188    /// WARNING: This skips signature verification! Never use in production.
189    pub fn permissive() -> Self {
190        Self::new(AuthConfig::dev_mode())
191    }
192
193    /// Get the config.
194    pub fn config(&self) -> &AuthConfig {
195        &self.config
196    }
197
198    /// Validate a JWT token and extract claims (sync version for HMAC).
199    /// For RSA algorithms, use `validate_token_async`.
200    pub fn validate_token(&self, token: &str) -> Result<Claims, AuthError> {
201        if self.config.skip_verification {
202            return self.decode_without_verification(token);
203        }
204
205        if self.config.is_hmac() {
206            self.validate_hmac(token)
207        } else {
208            // RSA requires async - return error for sync call
209            Err(AuthError::InvalidToken(
210                "RSA validation requires async. Use validate_token_async.".to_string(),
211            ))
212        }
213    }
214
215    /// Validate a JWT token and extract claims (async version, supports RSA).
216    pub async fn validate_token_async(&self, token: &str) -> Result<Claims, AuthError> {
217        if self.config.skip_verification {
218            return self.decode_without_verification(token);
219        }
220
221        if self.config.is_hmac() {
222            self.validate_hmac(token)
223        } else {
224            self.validate_rsa(token).await
225        }
226    }
227
228    /// Validate HMAC-signed token.
229    fn validate_hmac(&self, token: &str) -> Result<Claims, AuthError> {
230        let key = self.hmac_key.as_ref().ok_or_else(|| {
231            AuthError::InvalidToken("JWT secret not configured for HMAC".to_string())
232        })?;
233
234        self.decode_and_validate(token, key)
235    }
236
237    /// Validate RSA-signed token using JWKS.
238    async fn validate_rsa(&self, token: &str) -> Result<Claims, AuthError> {
239        let jwks = self.config.jwks_client.as_ref().ok_or_else(|| {
240            AuthError::InvalidToken("JWKS URL not configured for RSA".to_string())
241        })?;
242
243        // Extract key ID from token header
244        let header = jsonwebtoken::decode_header(token)
245            .map_err(|e| AuthError::InvalidToken(format!("Invalid token header: {}", e)))?;
246
247        debug!(kid = ?header.kid, alg = ?header.alg, "Validating RSA token");
248
249        // Get key from JWKS
250        let key = if let Some(kid) = header.kid {
251            jwks.get_key(&kid).await.map_err(|e| {
252                AuthError::InvalidToken(format!("Failed to get key '{}': {}", kid, e))
253            })?
254        } else {
255            jwks.get_any_key()
256                .await
257                .map_err(|e| AuthError::InvalidToken(format!("Failed to get JWKS key: {}", e)))?
258        };
259
260        self.decode_and_validate(token, &key)
261    }
262
263    /// Decode and validate token with the given key.
264    fn decode_and_validate(&self, token: &str, key: &DecodingKey) -> Result<Claims, AuthError> {
265        let mut validation = Validation::new(self.config.algorithm.into());
266
267        // Configure validation
268        validation.validate_exp = true;
269        validation.validate_nbf = false;
270        validation.leeway = 60; // 60 seconds clock skew tolerance
271
272        // Require exp and sub claims
273        validation.set_required_spec_claims(&["exp", "sub"]);
274
275        // Validate issuer if configured
276        if let Some(ref issuer) = self.config.issuer {
277            validation.set_issuer(&[issuer]);
278        }
279
280        // Validate audience if configured
281        if let Some(ref audience) = self.config.audience {
282            validation.set_audience(&[audience]);
283        } else {
284            validation.validate_aud = false;
285        }
286
287        let token_data =
288            decode::<Claims>(token, key, &validation).map_err(|e| self.map_jwt_error(e))?;
289
290        Ok(token_data.claims)
291    }
292
293    /// Map jsonwebtoken errors to AuthError.
294    fn map_jwt_error(&self, e: jsonwebtoken::errors::Error) -> AuthError {
295        match e.kind() {
296            jsonwebtoken::errors::ErrorKind::ExpiredSignature => AuthError::TokenExpired,
297            jsonwebtoken::errors::ErrorKind::InvalidSignature => {
298                AuthError::InvalidToken("Invalid signature".to_string())
299            }
300            jsonwebtoken::errors::ErrorKind::InvalidToken => {
301                AuthError::InvalidToken("Invalid token format".to_string())
302            }
303            jsonwebtoken::errors::ErrorKind::MissingRequiredClaim(claim) => {
304                AuthError::InvalidToken(format!("Missing required claim: {}", claim))
305            }
306            jsonwebtoken::errors::ErrorKind::InvalidIssuer => {
307                AuthError::InvalidToken("Invalid issuer".to_string())
308            }
309            jsonwebtoken::errors::ErrorKind::InvalidAudience => {
310                AuthError::InvalidToken("Invalid audience".to_string())
311            }
312            _ => AuthError::InvalidToken(e.to_string()),
313        }
314    }
315
316    /// Decode JWT token without signature verification (DEV MODE ONLY).
317    fn decode_without_verification(&self, token: &str) -> Result<Claims, AuthError> {
318        let token_data =
319            dangerous::insecure_decode::<Claims>(token).map_err(|e| match e.kind() {
320                jsonwebtoken::errors::ErrorKind::InvalidToken => {
321                    AuthError::InvalidToken("Invalid token format".to_string())
322                }
323                _ => AuthError::InvalidToken(e.to_string()),
324            })?;
325
326        // Still check expiration in dev mode
327        if token_data.claims.is_expired() {
328            return Err(AuthError::TokenExpired);
329        }
330
331        Ok(token_data.claims)
332    }
333}
334
335/// Authentication errors.
336#[derive(Debug, Clone, thiserror::Error)]
337pub enum AuthError {
338    #[error("Missing authorization header")]
339    MissingHeader,
340    #[error("Invalid authorization header format")]
341    InvalidHeader,
342    #[error("Invalid token: {0}")]
343    InvalidToken(String),
344    #[error("Token expired")]
345    TokenExpired,
346}
347
348/// Extract token from request headers.
349pub fn extract_token(req: &Request<Body>) -> Option<String> {
350    req.headers()
351        .get(axum::http::header::AUTHORIZATION)
352        .and_then(|v| v.to_str().ok())
353        .filter(|header| header.starts_with("Bearer "))
354        .map(|header| header.trim_start_matches("Bearer ").trim().to_string())
355}
356
357/// Extract auth context from token (async, supports both HMAC and RSA/JWKS).
358pub async fn extract_auth_context_async(
359    token: Option<String>,
360    middleware: &AuthMiddleware,
361) -> AuthContext {
362    match token {
363        Some(token) => match middleware.validate_token_async(&token).await {
364            Ok(claims) => build_auth_context_from_claims(claims),
365            Err(e) => {
366                tracing::warn!(error = %e, "Token validation failed");
367                AuthContext::unauthenticated()
368            }
369        },
370        None => AuthContext::unauthenticated(),
371    }
372}
373
374/// Build auth context from validated claims.
375///
376/// This handles both UUID and non-UUID subjects properly:
377/// - UUID subjects: uses `authenticated()` with the parsed UUID
378/// - Non-UUID subjects: uses `authenticated_without_uuid()` and stores raw subject in claims
379pub fn build_auth_context_from_claims(claims: Claims) -> AuthContext {
380    // Try to parse subject as UUID first (before moving claims)
381    let user_id = claims.user_id();
382
383    // Build custom claims with raw subject included
384    let mut custom_claims = claims.custom;
385    custom_claims.insert("sub".to_string(), serde_json::Value::String(claims.sub));
386
387    match user_id {
388        Some(uuid) => {
389            // Subject is a valid UUID
390            AuthContext::authenticated(uuid, claims.roles, custom_claims)
391        }
392        None => {
393            // Subject is not a UUID (e.g., Firebase uid, Clerk user_xxx, email)
394            // Still authenticated, but user_id() will return None
395            AuthContext::authenticated_without_uuid(claims.roles, custom_claims)
396        }
397    }
398}
399
400/// Authentication middleware function.
401pub async fn auth_middleware(
402    State(middleware): State<Arc<AuthMiddleware>>,
403    req: Request<Body>,
404    next: Next,
405) -> Response {
406    let token = extract_token(&req);
407    tracing::trace!(
408        token_present = token.is_some(),
409        "Auth middleware processing request"
410    );
411
412    let auth_context = extract_auth_context_async(token, &middleware).await;
413    tracing::trace!(
414        authenticated = auth_context.is_authenticated(),
415        "Auth context created"
416    );
417
418    let mut req = req;
419    req.extensions_mut().insert(auth_context);
420
421    next.run(req).await
422}
423
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header, encode};

    /// Claims for subject `test-user-id` with role `user`. They are valid for one
    /// hour, or, when `expired` is true, expired one hour ago.
    fn create_test_claims(expired: bool) -> Claims {
        use forge_core::auth::ClaimsBuilder;

        // A negative duration puts the expiry in the past.
        let duration = if expired { -3600 } else { 3600 };
        ClaimsBuilder::new()
            .subject("test-user-id")
            .role("user")
            .duration_secs(duration)
            .build()
            .unwrap()
    }

    /// Signs `claims` with `secret`, using the default JWT header.
    fn create_test_token(claims: &Claims, secret: &str) -> String {
        let key = EncodingKey::from_secret(secret.as_bytes());
        encode(&Header::default(), claims, &key).unwrap()
    }

    /// HMAC-verifying middleware keyed by `secret`.
    fn hmac_middleware(secret: &str) -> AuthMiddleware {
        AuthMiddleware::new(AuthConfig::with_secret(secret))
    }

    #[test]
    fn test_auth_config_default() {
        let config = AuthConfig::default();
        assert!(config.allow_anonymous);
        assert_eq!(config.algorithm, JwtAlgorithm::HS256);
        assert!(!config.skip_verification);
    }

    #[test]
    fn test_auth_config_dev_mode() {
        let config = AuthConfig::dev_mode();
        assert!(config.skip_verification);
        assert!(config.allow_anonymous);
    }

    #[test]
    fn test_auth_middleware_permissive() {
        assert!(AuthMiddleware::permissive().config.skip_verification);
    }

    #[test]
    fn test_valid_token_with_correct_secret() {
        let secret = "test-secret-key";
        let token = create_test_token(&create_test_claims(false), secret);

        let claims = hmac_middleware(secret)
            .validate_token(&token)
            .expect("token signed with the configured secret should validate");
        assert_eq!(claims.sub, "test-user-id");
    }

    #[test]
    fn test_valid_token_with_wrong_secret() {
        let token = create_test_token(&create_test_claims(false), "wrong-secret");
        let result = hmac_middleware("correct-secret").validate_token(&token);
        assert!(
            matches!(result, Err(AuthError::InvalidToken(_))),
            "Expected InvalidToken error"
        );
    }

    #[test]
    fn test_expired_token() {
        let secret = "test-secret";
        let token = create_test_token(&create_test_claims(true), secret);
        let result = hmac_middleware(secret).validate_token(&token);
        assert!(
            matches!(result, Err(AuthError::TokenExpired)),
            "Expected TokenExpired error"
        );
    }

    #[test]
    fn test_tampered_token() {
        let secret = "test-secret";
        let mut token = create_test_token(&create_test_claims(false), secret);

        // Replace the final signature character with a different one.
        if let Some(last) = token.pop() {
            token.push(if last == 'a' { 'b' } else { 'a' });
        }

        assert!(hmac_middleware(secret).validate_token(&token).is_err());
    }

    #[test]
    fn test_dev_mode_skips_signature() {
        let middleware = AuthMiddleware::new(AuthConfig::dev_mode());
        // Signed with an arbitrary secret: dev mode must not check the signature.
        let token = create_test_token(&create_test_claims(false), "any-secret");
        assert!(middleware.validate_token(&token).is_ok());
    }

    #[test]
    fn test_dev_mode_still_checks_expiration() {
        let middleware = AuthMiddleware::new(AuthConfig::dev_mode());
        let token = create_test_token(&create_test_claims(true), "any-secret");
        let result = middleware.validate_token(&token);
        assert!(
            matches!(result, Err(AuthError::TokenExpired)),
            "Expected TokenExpired error even in dev mode"
        );
    }

    #[test]
    fn test_invalid_token_format() {
        let result = hmac_middleware("secret").validate_token("not-a-valid-jwt");
        assert!(
            matches!(result, Err(AuthError::InvalidToken(_))),
            "Expected InvalidToken error"
        );
    }

    #[test]
    fn test_algorithm_conversion() {
        let pairs = [
            // HMAC algorithms
            (JwtAlgorithm::HS256, Algorithm::HS256),
            (JwtAlgorithm::HS384, Algorithm::HS384),
            (JwtAlgorithm::HS512, Algorithm::HS512),
            // RSA algorithms
            (JwtAlgorithm::RS256, Algorithm::RS256),
            (JwtAlgorithm::RS384, Algorithm::RS384),
            (JwtAlgorithm::RS512, Algorithm::RS512),
        ];
        for (ours, theirs) in pairs {
            assert_eq!(Algorithm::from(ours), theirs);
        }
    }

    #[test]
    fn test_is_hmac_and_is_rsa() {
        let hmac_config = AuthConfig::with_secret("test");
        assert!(hmac_config.is_hmac());
        assert!(!hmac_config.is_rsa());

        let rsa_config = AuthConfig {
            algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        assert!(!rsa_config.is_hmac());
        assert!(rsa_config.is_rsa());
    }
}