Skip to main content

forge_runtime/gateway/
auth.rs

1use std::sync::Arc;
2
3use axum::{
4    body::Body,
5    extract::{Request, State},
6    middleware::Next,
7    response::Response,
8};
9use forge_core::auth::Claims;
10use forge_core::config::JwtAlgorithm as CoreJwtAlgorithm;
11use forge_core::function::AuthContext;
12use jsonwebtoken::{Algorithm, DecodingKey, Validation, dangerous, decode};
13use tracing::debug;
14
15use super::jwks::JwksClient;
16
17/// Authentication configuration for the runtime.
18#[derive(Debug, Clone)]
19pub struct AuthConfig {
20    /// JWT secret for HMAC algorithms (HS256, HS384, HS512).
21    pub jwt_secret: Option<String>,
22    /// JWT algorithm.
23    pub algorithm: JwtAlgorithm,
24    /// JWKS client for RSA algorithms.
25    pub jwks_client: Option<Arc<JwksClient>>,
26    /// Expected token issuer (iss claim).
27    pub issuer: Option<String>,
28    /// Expected audience (aud claim).
29    pub audience: Option<String>,
30    /// Skip signature verification (DEV MODE ONLY - NEVER USE IN PRODUCTION).
31    pub skip_verification: bool,
32}
33
34impl Default for AuthConfig {
35    fn default() -> Self {
36        Self {
37            jwt_secret: None,
38            algorithm: JwtAlgorithm::HS256,
39            jwks_client: None,
40            issuer: None,
41            audience: None,
42            skip_verification: false,
43        }
44    }
45}
46
47impl AuthConfig {
48    /// Create auth config from forge core config.
49    pub fn from_forge_config(config: &forge_core::config::AuthConfig) -> Self {
50        let algorithm = JwtAlgorithm::from(config.jwt_algorithm);
51
52        let jwks_client = config
53            .jwks_url
54            .as_ref()
55            .map(|url| Arc::new(JwksClient::new(url.clone(), config.jwks_cache_ttl_secs)));
56
57        Self {
58            jwt_secret: config.jwt_secret.clone(),
59            algorithm,
60            jwks_client,
61            issuer: config.jwt_issuer.clone(),
62            audience: config.jwt_audience.clone(),
63            skip_verification: false,
64        }
65    }
66
67    /// Create a new auth config with the given HMAC secret.
68    pub fn with_secret(secret: impl Into<String>) -> Self {
69        Self {
70            jwt_secret: Some(secret.into()),
71            ..Default::default()
72        }
73    }
74
75    /// Create a dev mode config that skips signature verification.
76    /// WARNING: Only use this for development and testing!
77    pub fn dev_mode() -> Self {
78        Self {
79            jwt_secret: None,
80            algorithm: JwtAlgorithm::HS256,
81            jwks_client: None,
82            issuer: None,
83            audience: None,
84            skip_verification: true,
85        }
86    }
87
88    /// Check if this config uses HMAC (symmetric) algorithms.
89    pub fn is_hmac(&self) -> bool {
90        matches!(
91            self.algorithm,
92            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
93        )
94    }
95
96    /// Check if this config uses RSA (asymmetric) algorithms.
97    pub fn is_rsa(&self) -> bool {
98        matches!(
99            self.algorithm,
100            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
101        )
102    }
103}
104
/// JWT signing algorithms supported by the runtime.
///
/// `HS*` variants are HMAC (shared-secret) algorithms, verified with
/// `AuthConfig::jwt_secret`. `RS*` variants are RSA algorithms, verified with
/// keys obtained through the JWKS client. `HS256` is the default.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum JwtAlgorithm {
    /// HMAC using SHA-256.
    #[default]
    HS256,
    /// HMAC using SHA-384.
    HS384,
    /// HMAC using SHA-512.
    HS512,
    /// RSASSA-PKCS1-v1_5 using SHA-256.
    RS256,
    /// RSASSA-PKCS1-v1_5 using SHA-384.
    RS384,
    /// RSASSA-PKCS1-v1_5 using SHA-512.
    RS512,
}
116
117impl From<JwtAlgorithm> for Algorithm {
118    fn from(alg: JwtAlgorithm) -> Self {
119        match alg {
120            JwtAlgorithm::HS256 => Algorithm::HS256,
121            JwtAlgorithm::HS384 => Algorithm::HS384,
122            JwtAlgorithm::HS512 => Algorithm::HS512,
123            JwtAlgorithm::RS256 => Algorithm::RS256,
124            JwtAlgorithm::RS384 => Algorithm::RS384,
125            JwtAlgorithm::RS512 => Algorithm::RS512,
126        }
127    }
128}
129
130impl From<CoreJwtAlgorithm> for JwtAlgorithm {
131    fn from(alg: CoreJwtAlgorithm) -> Self {
132        match alg {
133            CoreJwtAlgorithm::HS256 => JwtAlgorithm::HS256,
134            CoreJwtAlgorithm::HS384 => JwtAlgorithm::HS384,
135            CoreJwtAlgorithm::HS512 => JwtAlgorithm::HS512,
136            CoreJwtAlgorithm::RS256 => JwtAlgorithm::RS256,
137            CoreJwtAlgorithm::RS384 => JwtAlgorithm::RS384,
138            CoreJwtAlgorithm::RS512 => JwtAlgorithm::RS512,
139        }
140    }
141}
142
143/// Authentication middleware.
144#[derive(Clone)]
145pub struct AuthMiddleware {
146    config: Arc<AuthConfig>,
147    /// Pre-computed HMAC decoding key (for performance).
148    hmac_key: Option<DecodingKey>,
149}
150
151impl std::fmt::Debug for AuthMiddleware {
152    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153        f.debug_struct("AuthMiddleware")
154            .field("config", &self.config)
155            .field("hmac_key", &self.hmac_key.is_some())
156            .finish()
157    }
158}
159
160impl AuthMiddleware {
161    /// Create a new auth middleware.
162    pub fn new(config: AuthConfig) -> Self {
163        // Pre-compute HMAC key if using HMAC algorithm
164        let hmac_key = if config.skip_verification {
165            None
166        } else if config.is_hmac() {
167            config
168                .jwt_secret
169                .as_ref()
170                .filter(|s| !s.is_empty())
171                .map(|secret| DecodingKey::from_secret(secret.as_bytes()))
172        } else {
173            None
174        };
175
176        Self {
177            config: Arc::new(config),
178            hmac_key,
179        }
180    }
181
182    /// Create a middleware that allows all requests (development mode).
183    /// WARNING: This skips signature verification! Never use in production.
184    pub fn permissive() -> Self {
185        Self::new(AuthConfig::dev_mode())
186    }
187
188    /// Get the config.
189    pub fn config(&self) -> &AuthConfig {
190        &self.config
191    }
192
193    /// Validate a JWT token and extract claims.
194    pub async fn validate_token_async(&self, token: &str) -> Result<Claims, AuthError> {
195        if self.config.skip_verification {
196            return self.decode_without_verification(token);
197        }
198
199        if self.config.is_hmac() {
200            self.validate_hmac(token)
201        } else {
202            self.validate_rsa(token).await
203        }
204    }
205
206    /// Validate HMAC-signed token.
207    fn validate_hmac(&self, token: &str) -> Result<Claims, AuthError> {
208        let key = self.hmac_key.as_ref().ok_or_else(|| {
209            AuthError::InvalidToken("JWT secret not configured for HMAC".to_string())
210        })?;
211
212        self.decode_and_validate(token, key)
213    }
214
215    /// Validate RSA-signed token using JWKS.
216    async fn validate_rsa(&self, token: &str) -> Result<Claims, AuthError> {
217        let jwks = self.config.jwks_client.as_ref().ok_or_else(|| {
218            AuthError::InvalidToken("JWKS URL not configured for RSA".to_string())
219        })?;
220
221        // Extract key ID from token header
222        let header = jsonwebtoken::decode_header(token)
223            .map_err(|e| AuthError::InvalidToken(format!("Invalid token header: {}", e)))?;
224
225        debug!(kid = ?header.kid, alg = ?header.alg, "Validating RSA token");
226
227        // Get key from JWKS
228        let key = if let Some(kid) = header.kid {
229            jwks.get_key(&kid).await.map_err(|e| {
230                AuthError::InvalidToken(format!("Failed to get key '{}': {}", kid, e))
231            })?
232        } else {
233            jwks.get_any_key()
234                .await
235                .map_err(|e| AuthError::InvalidToken(format!("Failed to get JWKS key: {}", e)))?
236        };
237
238        self.decode_and_validate(token, &key)
239    }
240
241    /// Decode and validate token with the given key.
242    fn decode_and_validate(&self, token: &str, key: &DecodingKey) -> Result<Claims, AuthError> {
243        let mut validation = Validation::new(self.config.algorithm.into());
244
245        // Configure validation
246        validation.validate_exp = true;
247        validation.validate_nbf = false;
248        validation.leeway = 60; // 60 seconds clock skew tolerance
249
250        // Require exp and sub claims
251        validation.set_required_spec_claims(&["exp", "sub"]);
252
253        // Validate issuer if configured
254        if let Some(ref issuer) = self.config.issuer {
255            validation.set_issuer(&[issuer]);
256        }
257
258        // Validate audience if configured
259        if let Some(ref audience) = self.config.audience {
260            validation.set_audience(&[audience]);
261        } else {
262            validation.validate_aud = false;
263        }
264
265        let token_data =
266            decode::<Claims>(token, key, &validation).map_err(|e| self.map_jwt_error(e))?;
267
268        Ok(token_data.claims)
269    }
270
271    /// Map jsonwebtoken errors to AuthError.
272    fn map_jwt_error(&self, e: jsonwebtoken::errors::Error) -> AuthError {
273        match e.kind() {
274            jsonwebtoken::errors::ErrorKind::ExpiredSignature => AuthError::TokenExpired,
275            jsonwebtoken::errors::ErrorKind::InvalidSignature => {
276                AuthError::InvalidToken("Invalid signature".to_string())
277            }
278            jsonwebtoken::errors::ErrorKind::InvalidToken => {
279                AuthError::InvalidToken("Invalid token format".to_string())
280            }
281            jsonwebtoken::errors::ErrorKind::MissingRequiredClaim(claim) => {
282                AuthError::InvalidToken(format!("Missing required claim: {}", claim))
283            }
284            jsonwebtoken::errors::ErrorKind::InvalidIssuer => {
285                AuthError::InvalidToken("Invalid issuer".to_string())
286            }
287            jsonwebtoken::errors::ErrorKind::InvalidAudience => {
288                AuthError::InvalidToken("Invalid audience".to_string())
289            }
290            _ => AuthError::InvalidToken(e.to_string()),
291        }
292    }
293
294    /// Decode JWT token without signature verification (DEV MODE ONLY).
295    fn decode_without_verification(&self, token: &str) -> Result<Claims, AuthError> {
296        let token_data =
297            dangerous::insecure_decode::<Claims>(token).map_err(|e| match e.kind() {
298                jsonwebtoken::errors::ErrorKind::InvalidToken => {
299                    AuthError::InvalidToken("Invalid token format".to_string())
300                }
301                _ => AuthError::InvalidToken(e.to_string()),
302            })?;
303
304        // Still check expiration in dev mode
305        if token_data.claims.is_expired() {
306            return Err(AuthError::TokenExpired);
307        }
308
309        Ok(token_data.claims)
310    }
311}
312
/// Authentication errors.
///
/// Returned by `AuthMiddleware::validate_token_async`. When that call is made
/// through `extract_auth_context_async`, the error is logged at `warn` level
/// and the request proceeds as unauthenticated rather than being rejected.
#[derive(Debug, Clone, thiserror::Error)]
pub enum AuthError {
    /// No authorization header was supplied (not constructed anywhere in this
    /// file).
    #[error("Missing authorization header")]
    MissingHeader,
    /// The authorization header was malformed (not constructed anywhere in
    /// this file).
    #[error("Invalid authorization header format")]
    InvalidHeader,
    /// The token failed decoding, signature verification, or claim checks, or
    /// the key material needed to verify it is unavailable. The payload is a
    /// human-readable detail message.
    #[error("Invalid token: {0}")]
    InvalidToken(String),
    /// The token's `exp` claim has passed. Checked in both the verified path
    /// and the dev-mode (unverified) path.
    #[error("Token expired")]
    TokenExpired,
}
325
326/// Extract token from request headers.
327pub fn extract_token(req: &Request<Body>) -> Option<String> {
328    req.headers()
329        .get(axum::http::header::AUTHORIZATION)
330        .and_then(|v| v.to_str().ok())
331        .filter(|header| header.starts_with("Bearer "))
332        .map(|header| header.trim_start_matches("Bearer ").trim().to_string())
333}
334
335/// Extract auth context from token (async, supports both HMAC and RSA/JWKS).
336pub async fn extract_auth_context_async(
337    token: Option<String>,
338    middleware: &AuthMiddleware,
339) -> AuthContext {
340    match token {
341        Some(token) => match middleware.validate_token_async(&token).await {
342            Ok(claims) => build_auth_context_from_claims(claims),
343            Err(e) => {
344                tracing::warn!(error = %e, "Token validation failed");
345                AuthContext::unauthenticated()
346            }
347        },
348        None => AuthContext::unauthenticated(),
349    }
350}
351
352/// Build auth context from validated claims.
353///
354/// This handles both UUID and non-UUID subjects properly:
355/// - UUID subjects: uses `authenticated()` with the parsed UUID
356/// - Non-UUID subjects: uses `authenticated_without_uuid()` and stores raw subject in claims
357pub fn build_auth_context_from_claims(claims: Claims) -> AuthContext {
358    // Try to parse subject as UUID first (before moving claims)
359    let user_id = claims.user_id();
360
361    // Build custom claims with raw subject included
362    let mut custom_claims = claims.custom;
363    custom_claims.insert("sub".to_string(), serde_json::Value::String(claims.sub));
364
365    match user_id {
366        Some(uuid) => {
367            // Subject is a valid UUID
368            AuthContext::authenticated(uuid, claims.roles, custom_claims)
369        }
370        None => {
371            // Subject is not a UUID (e.g., Firebase uid, Clerk user_xxx, email)
372            // Still authenticated, but user_id() will return None
373            AuthContext::authenticated_without_uuid(claims.roles, custom_claims)
374        }
375    }
376}
377
378/// Authentication middleware function.
379pub async fn auth_middleware(
380    State(middleware): State<Arc<AuthMiddleware>>,
381    req: Request<Body>,
382    next: Next,
383) -> Response {
384    let token = extract_token(&req);
385    tracing::trace!(
386        token_present = token.is_some(),
387        "Auth middleware processing request"
388    );
389
390    let auth_context = extract_auth_context_async(token, &middleware).await;
391    tracing::trace!(
392        authenticated = auth_context.is_authenticated(),
393        "Auth context created"
394    );
395
396    let mut req = req;
397    req.extensions_mut().insert(auth_context);
398
399    next.run(req).await
400}
401
/// Unit tests for token validation and auth configuration.
///
/// Tokens are signed locally with `jsonwebtoken::encode` and `Header::default()`.
/// Claims come from `forge_core::auth::ClaimsBuilder`.
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header, encode};

    /// Build claims for subject `test-user-id` with role `user`.
    ///
    /// `expired` selects a lifetime of -3600 or +3600 seconds. This presumably
    /// sets `exp` relative to the current time — TODO confirm in
    /// `ClaimsBuilder::duration_secs`.
    fn create_test_claims(expired: bool) -> Claims {
        use forge_core::auth::ClaimsBuilder;

        let mut builder = ClaimsBuilder::new().subject("test-user-id").role("user");

        if expired {
            builder = builder.duration_secs(-3600); // Expired 1 hour ago
        } else {
            builder = builder.duration_secs(3600); // Valid for 1 hour
        }

        builder.build().unwrap()
    }

    /// Sign `claims` with the HMAC `secret` using `Header::default()`.
    fn create_test_token(claims: &Claims, secret: &str) -> String {
        encode(
            &Header::default(),
            claims,
            &EncodingKey::from_secret(secret.as_bytes()),
        )
        .unwrap()
    }

    // Default config: HS256 with signature verification enabled.
    #[test]
    fn test_auth_config_default() {
        let config = AuthConfig::default();
        assert_eq!(config.algorithm, JwtAlgorithm::HS256);
        assert!(!config.skip_verification);
    }

    // Dev mode turns signature verification off.
    #[test]
    fn test_auth_config_dev_mode() {
        let config = AuthConfig::dev_mode();
        assert!(config.skip_verification);
    }

    // `permissive()` is built on the dev-mode config.
    #[test]
    fn test_auth_middleware_permissive() {
        let middleware = AuthMiddleware::permissive();
        assert!(middleware.config.skip_verification);
    }

    // Happy path: token signed with the configured secret is accepted.
    #[tokio::test]
    async fn test_valid_token_with_correct_secret() {
        let secret = "test-secret-key";
        let config = AuthConfig::with_secret(secret);
        let middleware = AuthMiddleware::new(config);

        let claims = create_test_claims(false);
        let token = create_test_token(&claims, secret);

        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_ok());
        let validated_claims = result.unwrap();
        assert_eq!(validated_claims.sub, "test-user-id");
    }

    // A token signed with a different secret must be rejected as invalid.
    #[tokio::test]
    async fn test_valid_token_with_wrong_secret() {
        let config = AuthConfig::with_secret("correct-secret");
        let middleware = AuthMiddleware::new(config);

        let claims = create_test_claims(false);
        let token = create_test_token(&claims, "wrong-secret");

        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_err());
        match result {
            Err(AuthError::InvalidToken(_)) => {}
            _ => panic!("Expected InvalidToken error"),
        }
    }

    // Expired tokens surface as `TokenExpired`, not a generic invalid token.
    #[tokio::test]
    async fn test_expired_token() {
        let secret = "test-secret";
        let config = AuthConfig::with_secret(secret);
        let middleware = AuthMiddleware::new(config);

        let claims = create_test_claims(true); // Expired
        let token = create_test_token(&claims, secret);

        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_err());
        match result {
            Err(AuthError::TokenExpired) => {}
            _ => panic!("Expected TokenExpired error"),
        }
    }

    // Altering the last signature character must invalidate the token.
    #[tokio::test]
    async fn test_tampered_token() {
        let secret = "test-secret";
        let config = AuthConfig::with_secret(secret);
        let middleware = AuthMiddleware::new(config);

        let claims = create_test_claims(false);
        let mut token = create_test_token(&claims, secret);

        // Tamper with the token by modifying a character in the signature
        if let Some(last_char) = token.pop() {
            let replacement = if last_char == 'a' { 'b' } else { 'a' };
            token.push(replacement);
        }

        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_err());
    }

    // Dev mode accepts a token signed with an arbitrary secret.
    #[tokio::test]
    async fn test_dev_mode_skips_signature() {
        let config = AuthConfig::dev_mode();
        let middleware = AuthMiddleware::new(config);

        // Create token with any secret
        let claims = create_test_claims(false);
        let token = create_test_token(&claims, "any-secret");

        // Should still validate in dev mode
        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_ok());
    }

    // Dev mode still enforces expiry.
    #[tokio::test]
    async fn test_dev_mode_still_checks_expiration() {
        let config = AuthConfig::dev_mode();
        let middleware = AuthMiddleware::new(config);

        let claims = create_test_claims(true); // Expired
        let token = create_test_token(&claims, "any-secret");

        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_err());
        match result {
            Err(AuthError::TokenExpired) => {}
            _ => panic!("Expected TokenExpired error even in dev mode"),
        }
    }

    // A string that is not a JWT at all is reported as `InvalidToken`.
    #[tokio::test]
    async fn test_invalid_token_format() {
        let config = AuthConfig::with_secret("secret");
        let middleware = AuthMiddleware::new(config);

        let result = middleware.validate_token_async("not-a-valid-jwt").await;
        assert!(result.is_err());
        match result {
            Err(AuthError::InvalidToken(_)) => {}
            _ => panic!("Expected InvalidToken error"),
        }
    }

    // Every runtime variant maps to the same-named `jsonwebtoken` algorithm.
    #[test]
    fn test_algorithm_conversion() {
        // HMAC algorithms
        assert_eq!(Algorithm::from(JwtAlgorithm::HS256), Algorithm::HS256);
        assert_eq!(Algorithm::from(JwtAlgorithm::HS384), Algorithm::HS384);
        assert_eq!(Algorithm::from(JwtAlgorithm::HS512), Algorithm::HS512);
        // RSA algorithms
        assert_eq!(Algorithm::from(JwtAlgorithm::RS256), Algorithm::RS256);
        assert_eq!(Algorithm::from(JwtAlgorithm::RS384), Algorithm::RS384);
        assert_eq!(Algorithm::from(JwtAlgorithm::RS512), Algorithm::RS512);
    }

    // `is_hmac` and `is_rsa` are mutually exclusive for both families.
    #[test]
    fn test_is_hmac_and_is_rsa() {
        let hmac_config = AuthConfig::with_secret("test");
        assert!(hmac_config.is_hmac());
        assert!(!hmac_config.is_rsa());

        let rsa_config = AuthConfig {
            algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        assert!(!rsa_config.is_hmac());
        assert!(rsa_config.is_rsa());
    }
}