Skip to main content

forge_runtime/gateway/
auth.rs

1use std::sync::Arc;
2
3use axum::{
4    body::Body,
5    extract::{Request, State},
6    middleware::Next,
7    response::Response,
8};
9use forge_core::auth::Claims;
10use forge_core::config::JwtAlgorithm as CoreJwtAlgorithm;
11use forge_core::function::AuthContext;
12use jsonwebtoken::{Algorithm, DecodingKey, Validation, dangerous, decode};
13use tracing::debug;
14
15use super::jwks::JwksClient;
16
17/// Authentication configuration for the runtime.
18#[derive(Debug, Clone)]
19pub struct AuthConfig {
20    /// JWT secret for HMAC algorithms (HS256, HS384, HS512).
21    pub jwt_secret: Option<String>,
22    /// JWT algorithm.
23    pub algorithm: JwtAlgorithm,
24    /// JWKS client for RSA algorithms.
25    pub jwks_client: Option<Arc<JwksClient>>,
26    /// Expected token issuer (iss claim).
27    pub issuer: Option<String>,
28    /// Expected audience (aud claim).
29    pub audience: Option<String>,
30    /// Whether to allow unauthenticated requests.
31    pub allow_anonymous: bool,
32    /// Skip signature verification (DEV MODE ONLY - NEVER USE IN PRODUCTION).
33    pub skip_verification: bool,
34}
35
36impl Default for AuthConfig {
37    fn default() -> Self {
38        Self {
39            jwt_secret: None,
40            algorithm: JwtAlgorithm::HS256,
41            jwks_client: None,
42            issuer: None,
43            audience: None,
44            allow_anonymous: true,
45            skip_verification: false,
46        }
47    }
48}
49
50impl AuthConfig {
51    /// Create auth config from forge core config.
52    pub fn from_forge_config(config: &forge_core::config::AuthConfig) -> Self {
53        let algorithm = JwtAlgorithm::from(config.jwt_algorithm);
54
55        let jwks_client = config
56            .jwks_url
57            .as_ref()
58            .map(|url| Arc::new(JwksClient::new(url.clone(), config.jwks_cache_ttl_secs)));
59
60        Self {
61            jwt_secret: config.jwt_secret.clone(),
62            algorithm,
63            jwks_client,
64            issuer: config.jwt_issuer.clone(),
65            audience: config.jwt_audience.clone(),
66            allow_anonymous: config.allow_anonymous,
67            skip_verification: false,
68        }
69    }
70
71    /// Create a new auth config with the given HMAC secret.
72    pub fn with_secret(secret: impl Into<String>) -> Self {
73        Self {
74            jwt_secret: Some(secret.into()),
75            ..Default::default()
76        }
77    }
78
79    /// Create a dev mode config that skips signature verification.
80    /// WARNING: Only use this for development and testing!
81    pub fn dev_mode() -> Self {
82        Self {
83            jwt_secret: None,
84            algorithm: JwtAlgorithm::HS256,
85            jwks_client: None,
86            issuer: None,
87            audience: None,
88            allow_anonymous: true,
89            skip_verification: true,
90        }
91    }
92
93    /// Check if this config uses HMAC (symmetric) algorithms.
94    pub fn is_hmac(&self) -> bool {
95        matches!(
96            self.algorithm,
97            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
98        )
99    }
100
101    /// Check if this config uses RSA (asymmetric) algorithms.
102    pub fn is_rsa(&self) -> bool {
103        matches!(
104            self.algorithm,
105            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
106        )
107    }
108}
109
/// Supported JWT algorithms.
///
/// The `HS*` variants are the HMAC (shared-secret) family and the `RS*`
/// variants are the RSA family; see `AuthConfig::is_hmac` and
/// `AuthConfig::is_rsa`. Each variant converts to the `jsonwebtoken`
/// `Algorithm` variant of the same name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum JwtAlgorithm {
    /// HMAC family, 256-bit variant. This is the default algorithm.
    #[default]
    HS256,
    /// HMAC family, 384-bit variant.
    HS384,
    /// HMAC family, 512-bit variant.
    HS512,
    /// RSA family, 256-bit variant.
    RS256,
    /// RSA family, 384-bit variant.
    RS384,
    /// RSA family, 512-bit variant.
    RS512,
}
121
122impl From<JwtAlgorithm> for Algorithm {
123    fn from(alg: JwtAlgorithm) -> Self {
124        match alg {
125            JwtAlgorithm::HS256 => Algorithm::HS256,
126            JwtAlgorithm::HS384 => Algorithm::HS384,
127            JwtAlgorithm::HS512 => Algorithm::HS512,
128            JwtAlgorithm::RS256 => Algorithm::RS256,
129            JwtAlgorithm::RS384 => Algorithm::RS384,
130            JwtAlgorithm::RS512 => Algorithm::RS512,
131        }
132    }
133}
134
135impl From<CoreJwtAlgorithm> for JwtAlgorithm {
136    fn from(alg: CoreJwtAlgorithm) -> Self {
137        match alg {
138            CoreJwtAlgorithm::HS256 => JwtAlgorithm::HS256,
139            CoreJwtAlgorithm::HS384 => JwtAlgorithm::HS384,
140            CoreJwtAlgorithm::HS512 => JwtAlgorithm::HS512,
141            CoreJwtAlgorithm::RS256 => JwtAlgorithm::RS256,
142            CoreJwtAlgorithm::RS384 => JwtAlgorithm::RS384,
143            CoreJwtAlgorithm::RS512 => JwtAlgorithm::RS512,
144        }
145    }
146}
147
/// Authentication middleware.
///
/// Holds the shared auth configuration and, when one could be derived at
/// construction time, the HMAC decoding key used to verify tokens.
/// Cloning copies the `Arc` handle to the configuration rather than the
/// configuration itself.
#[derive(Clone)]
pub struct AuthMiddleware {
    /// Shared, immutable authentication settings.
    config: Arc<AuthConfig>,
    /// Pre-computed HMAC decoding key (for performance).
    ///
    /// `None` when verification is skipped, when the algorithm is not an
    /// HMAC variant, or when no non-empty secret is configured.
    hmac_key: Option<DecodingKey>,
}
155
156impl std::fmt::Debug for AuthMiddleware {
157    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158        f.debug_struct("AuthMiddleware")
159            .field("config", &self.config)
160            .field("hmac_key", &self.hmac_key.is_some())
161            .finish()
162    }
163}
164
165impl AuthMiddleware {
166    /// Create a new auth middleware.
167    pub fn new(config: AuthConfig) -> Self {
168        // Pre-compute HMAC key if using HMAC algorithm
169        let hmac_key = if config.skip_verification {
170            None
171        } else if config.is_hmac() {
172            config
173                .jwt_secret
174                .as_ref()
175                .filter(|s| !s.is_empty())
176                .map(|secret| DecodingKey::from_secret(secret.as_bytes()))
177        } else {
178            None
179        };
180
181        Self {
182            config: Arc::new(config),
183            hmac_key,
184        }
185    }
186
187    /// Create a middleware that allows all requests (development mode).
188    /// WARNING: This skips signature verification! Never use in production.
189    pub fn permissive() -> Self {
190        Self::new(AuthConfig::dev_mode())
191    }
192
193    /// Get the config.
194    pub fn config(&self) -> &AuthConfig {
195        &self.config
196    }
197
198    /// Validate a JWT token and extract claims.
199    pub async fn validate_token_async(&self, token: &str) -> Result<Claims, AuthError> {
200        if self.config.skip_verification {
201            return self.decode_without_verification(token);
202        }
203
204        if self.config.is_hmac() {
205            self.validate_hmac(token)
206        } else {
207            self.validate_rsa(token).await
208        }
209    }
210
211    /// Validate HMAC-signed token.
212    fn validate_hmac(&self, token: &str) -> Result<Claims, AuthError> {
213        let key = self.hmac_key.as_ref().ok_or_else(|| {
214            AuthError::InvalidToken("JWT secret not configured for HMAC".to_string())
215        })?;
216
217        self.decode_and_validate(token, key)
218    }
219
220    /// Validate RSA-signed token using JWKS.
221    async fn validate_rsa(&self, token: &str) -> Result<Claims, AuthError> {
222        let jwks = self.config.jwks_client.as_ref().ok_or_else(|| {
223            AuthError::InvalidToken("JWKS URL not configured for RSA".to_string())
224        })?;
225
226        // Extract key ID from token header
227        let header = jsonwebtoken::decode_header(token)
228            .map_err(|e| AuthError::InvalidToken(format!("Invalid token header: {}", e)))?;
229
230        debug!(kid = ?header.kid, alg = ?header.alg, "Validating RSA token");
231
232        // Get key from JWKS
233        let key = if let Some(kid) = header.kid {
234            jwks.get_key(&kid).await.map_err(|e| {
235                AuthError::InvalidToken(format!("Failed to get key '{}': {}", kid, e))
236            })?
237        } else {
238            jwks.get_any_key()
239                .await
240                .map_err(|e| AuthError::InvalidToken(format!("Failed to get JWKS key: {}", e)))?
241        };
242
243        self.decode_and_validate(token, &key)
244    }
245
246    /// Decode and validate token with the given key.
247    fn decode_and_validate(&self, token: &str, key: &DecodingKey) -> Result<Claims, AuthError> {
248        let mut validation = Validation::new(self.config.algorithm.into());
249
250        // Configure validation
251        validation.validate_exp = true;
252        validation.validate_nbf = false;
253        validation.leeway = 60; // 60 seconds clock skew tolerance
254
255        // Require exp and sub claims
256        validation.set_required_spec_claims(&["exp", "sub"]);
257
258        // Validate issuer if configured
259        if let Some(ref issuer) = self.config.issuer {
260            validation.set_issuer(&[issuer]);
261        }
262
263        // Validate audience if configured
264        if let Some(ref audience) = self.config.audience {
265            validation.set_audience(&[audience]);
266        } else {
267            validation.validate_aud = false;
268        }
269
270        let token_data =
271            decode::<Claims>(token, key, &validation).map_err(|e| self.map_jwt_error(e))?;
272
273        Ok(token_data.claims)
274    }
275
276    /// Map jsonwebtoken errors to AuthError.
277    fn map_jwt_error(&self, e: jsonwebtoken::errors::Error) -> AuthError {
278        match e.kind() {
279            jsonwebtoken::errors::ErrorKind::ExpiredSignature => AuthError::TokenExpired,
280            jsonwebtoken::errors::ErrorKind::InvalidSignature => {
281                AuthError::InvalidToken("Invalid signature".to_string())
282            }
283            jsonwebtoken::errors::ErrorKind::InvalidToken => {
284                AuthError::InvalidToken("Invalid token format".to_string())
285            }
286            jsonwebtoken::errors::ErrorKind::MissingRequiredClaim(claim) => {
287                AuthError::InvalidToken(format!("Missing required claim: {}", claim))
288            }
289            jsonwebtoken::errors::ErrorKind::InvalidIssuer => {
290                AuthError::InvalidToken("Invalid issuer".to_string())
291            }
292            jsonwebtoken::errors::ErrorKind::InvalidAudience => {
293                AuthError::InvalidToken("Invalid audience".to_string())
294            }
295            _ => AuthError::InvalidToken(e.to_string()),
296        }
297    }
298
299    /// Decode JWT token without signature verification (DEV MODE ONLY).
300    fn decode_without_verification(&self, token: &str) -> Result<Claims, AuthError> {
301        let token_data =
302            dangerous::insecure_decode::<Claims>(token).map_err(|e| match e.kind() {
303                jsonwebtoken::errors::ErrorKind::InvalidToken => {
304                    AuthError::InvalidToken("Invalid token format".to_string())
305                }
306                _ => AuthError::InvalidToken(e.to_string()),
307            })?;
308
309        // Still check expiration in dev mode
310        if token_data.claims.is_expired() {
311            return Err(AuthError::TokenExpired);
312        }
313
314        Ok(token_data.claims)
315    }
316}
317
/// Authentication errors.
///
/// The `#[error]` attributes supply the `Display` text for each variant.
#[derive(Debug, Clone, thiserror::Error)]
pub enum AuthError {
    /// No authorization header was supplied.
    /// NOTE(review): not constructed anywhere in the visible part of this file.
    #[error("Missing authorization header")]
    MissingHeader,
    /// The authorization header did not have the expected format.
    /// NOTE(review): not constructed anywhere in the visible part of this file.
    #[error("Invalid authorization header format")]
    InvalidHeader,
    /// The token could not be decoded or failed validation; the payload is
    /// a human-readable reason.
    #[error("Invalid token: {0}")]
    InvalidToken(String),
    /// The token's expiry time has passed.
    #[error("Token expired")]
    TokenExpired,
}
330
331/// Extract token from request headers.
332pub fn extract_token(req: &Request<Body>) -> Option<String> {
333    req.headers()
334        .get(axum::http::header::AUTHORIZATION)
335        .and_then(|v| v.to_str().ok())
336        .filter(|header| header.starts_with("Bearer "))
337        .map(|header| header.trim_start_matches("Bearer ").trim().to_string())
338}
339
340/// Extract auth context from token (async, supports both HMAC and RSA/JWKS).
341pub async fn extract_auth_context_async(
342    token: Option<String>,
343    middleware: &AuthMiddleware,
344) -> AuthContext {
345    match token {
346        Some(token) => match middleware.validate_token_async(&token).await {
347            Ok(claims) => build_auth_context_from_claims(claims),
348            Err(e) => {
349                tracing::warn!(error = %e, "Token validation failed");
350                AuthContext::unauthenticated()
351            }
352        },
353        None => AuthContext::unauthenticated(),
354    }
355}
356
357/// Build auth context from validated claims.
358///
359/// This handles both UUID and non-UUID subjects properly:
360/// - UUID subjects: uses `authenticated()` with the parsed UUID
361/// - Non-UUID subjects: uses `authenticated_without_uuid()` and stores raw subject in claims
362pub fn build_auth_context_from_claims(claims: Claims) -> AuthContext {
363    // Try to parse subject as UUID first (before moving claims)
364    let user_id = claims.user_id();
365
366    // Build custom claims with raw subject included
367    let mut custom_claims = claims.custom;
368    custom_claims.insert("sub".to_string(), serde_json::Value::String(claims.sub));
369
370    match user_id {
371        Some(uuid) => {
372            // Subject is a valid UUID
373            AuthContext::authenticated(uuid, claims.roles, custom_claims)
374        }
375        None => {
376            // Subject is not a UUID (e.g., Firebase uid, Clerk user_xxx, email)
377            // Still authenticated, but user_id() will return None
378            AuthContext::authenticated_without_uuid(claims.roles, custom_claims)
379        }
380    }
381}
382
383/// Authentication middleware function.
384pub async fn auth_middleware(
385    State(middleware): State<Arc<AuthMiddleware>>,
386    req: Request<Body>,
387    next: Next,
388) -> Response {
389    let token = extract_token(&req);
390    tracing::trace!(
391        token_present = token.is_some(),
392        "Auth middleware processing request"
393    );
394
395    let auth_context = extract_auth_context_async(token, &middleware).await;
396    tracing::trace!(
397        authenticated = auth_context.is_authenticated(),
398        "Auth context created"
399    );
400
401    let mut req = req;
402    req.extensions_mut().insert(auth_context);
403
404    next.run(req).await
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header, encode};

    /// Claims for subject `test-user-id` with role `user`, expiring one hour
    /// in the past when `expired` is set and one hour in the future otherwise.
    fn create_test_claims(expired: bool) -> Claims {
        use forge_core::auth::ClaimsBuilder;

        let lifetime_secs = if expired { -3600 } else { 3600 };

        ClaimsBuilder::new()
            .subject("test-user-id")
            .role("user")
            .duration_secs(lifetime_secs)
            .build()
            .unwrap()
    }

    /// Sign `claims` with `secret` using the default JWT header.
    fn create_test_token(claims: &Claims, secret: &str) -> String {
        let signing_key = EncodingKey::from_secret(secret.as_bytes());
        encode(&Header::default(), claims, &signing_key).unwrap()
    }

    /// Middleware configured with `secret` as its HMAC secret.
    fn hmac_middleware(secret: &str) -> AuthMiddleware {
        AuthMiddleware::new(AuthConfig::with_secret(secret))
    }

    #[test]
    fn test_auth_config_default() {
        let defaults = AuthConfig::default();

        assert!(defaults.allow_anonymous);
        assert_eq!(defaults.algorithm, JwtAlgorithm::HS256);
        assert!(!defaults.skip_verification);
    }

    #[test]
    fn test_auth_config_dev_mode() {
        let dev = AuthConfig::dev_mode();

        assert!(dev.skip_verification);
        assert!(dev.allow_anonymous);
    }

    #[test]
    fn test_auth_middleware_permissive() {
        assert!(AuthMiddleware::permissive().config.skip_verification);
    }

    #[tokio::test]
    async fn test_valid_token_with_correct_secret() {
        let secret = "test-secret-key";
        let middleware = hmac_middleware(secret);
        let token = create_test_token(&create_test_claims(false), secret);

        let outcome = middleware.validate_token_async(&token).await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().sub, "test-user-id");
    }

    #[tokio::test]
    async fn test_valid_token_with_wrong_secret() {
        let middleware = hmac_middleware("correct-secret");
        let token = create_test_token(&create_test_claims(false), "wrong-secret");

        let outcome = middleware.validate_token_async(&token).await;

        assert!(outcome.is_err());
        let Err(AuthError::InvalidToken(_)) = outcome else {
            panic!("Expected InvalidToken error");
        };
    }

    #[tokio::test]
    async fn test_expired_token() {
        let secret = "test-secret";
        let middleware = hmac_middleware(secret);
        let token = create_test_token(&create_test_claims(true), secret);

        let outcome = middleware.validate_token_async(&token).await;

        assert!(outcome.is_err());
        let Err(AuthError::TokenExpired) = outcome else {
            panic!("Expected TokenExpired error");
        };
    }

    #[tokio::test]
    async fn test_tampered_token() {
        let secret = "test-secret";
        let middleware = hmac_middleware(secret);
        let mut token = create_test_token(&create_test_claims(false), secret);

        // Swap the final signature character for a different one.
        if let Some(last_char) = token.pop() {
            token.push(if last_char == 'a' { 'b' } else { 'a' });
        }

        let outcome = middleware.validate_token_async(&token).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_dev_mode_skips_signature() {
        let middleware = AuthMiddleware::new(AuthConfig::dev_mode());

        // The signing secret is irrelevant because the signature is not checked.
        let token = create_test_token(&create_test_claims(false), "any-secret");

        let outcome = middleware.validate_token_async(&token).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_dev_mode_still_checks_expiration() {
        let middleware = AuthMiddleware::new(AuthConfig::dev_mode());
        let token = create_test_token(&create_test_claims(true), "any-secret");

        let outcome = middleware.validate_token_async(&token).await;

        assert!(outcome.is_err());
        let Err(AuthError::TokenExpired) = outcome else {
            panic!("Expected TokenExpired error even in dev mode");
        };
    }

    #[tokio::test]
    async fn test_invalid_token_format() {
        let middleware = hmac_middleware("secret");

        let outcome = middleware.validate_token_async("not-a-valid-jwt").await;

        assert!(outcome.is_err());
        let Err(AuthError::InvalidToken(_)) = outcome else {
            panic!("Expected InvalidToken error");
        };
    }

    #[test]
    fn test_algorithm_conversion() {
        let expected_pairs = [
            // HMAC algorithms
            (JwtAlgorithm::HS256, Algorithm::HS256),
            (JwtAlgorithm::HS384, Algorithm::HS384),
            (JwtAlgorithm::HS512, Algorithm::HS512),
            // RSA algorithms
            (JwtAlgorithm::RS256, Algorithm::RS256),
            (JwtAlgorithm::RS384, Algorithm::RS384),
            (JwtAlgorithm::RS512, Algorithm::RS512),
        ];

        for (runtime_alg, library_alg) in expected_pairs {
            assert_eq!(Algorithm::from(runtime_alg), library_alg);
        }
    }

    #[test]
    fn test_is_hmac_and_is_rsa() {
        let hmac_config = AuthConfig::with_secret("test");
        assert!(hmac_config.is_hmac());
        assert!(!hmac_config.is_rsa());

        let rsa_config = AuthConfig {
            algorithm: JwtAlgorithm::RS256,
            ..AuthConfig::default()
        };
        assert!(!rsa_config.is_hmac());
        assert!(rsa_config.is_rsa());
    }
}