Skip to main content

forge_runtime/gateway/
auth.rs

1use std::sync::Arc;
2
3use axum::{
4    body::Body,
5    extract::{Request, State},
6    http::StatusCode,
7    middleware::Next,
8    response::{IntoResponse, Response},
9};
10use forge_core::auth::Claims;
11use forge_core::config::JwtAlgorithm as CoreJwtAlgorithm;
12use forge_core::function::AuthContext;
13use jsonwebtoken::{Algorithm, DecodingKey, Validation, dangerous, decode};
14use tracing::debug;
15
16use super::jwks::JwksClient;
17
18/// Authentication configuration for the runtime.
19#[derive(Debug, Clone)]
20pub struct AuthConfig {
21    /// JWT secret for HMAC algorithms (HS256, HS384, HS512).
22    pub jwt_secret: Option<String>,
23    /// JWT algorithm.
24    pub algorithm: JwtAlgorithm,
25    /// JWKS client for RSA algorithms.
26    pub jwks_client: Option<Arc<JwksClient>>,
27    /// Expected token issuer (iss claim).
28    pub issuer: Option<String>,
29    /// Expected audience (aud claim).
30    pub audience: Option<String>,
31    /// Skip signature verification (DEV MODE ONLY - NEVER USE IN PRODUCTION).
32    pub skip_verification: bool,
33}
34
35impl Default for AuthConfig {
36    fn default() -> Self {
37        Self {
38            jwt_secret: None,
39            algorithm: JwtAlgorithm::HS256,
40            jwks_client: None,
41            issuer: None,
42            audience: None,
43            skip_verification: false,
44        }
45    }
46}
47
48impl AuthConfig {
49    /// Create auth config from forge core config.
50    pub fn from_forge_config(
51        config: &forge_core::config::AuthConfig,
52    ) -> Result<Self, super::jwks::JwksError> {
53        let algorithm = JwtAlgorithm::from(config.jwt_algorithm);
54
55        let jwks_client = config
56            .jwks_url
57            .as_ref()
58            .map(|url| JwksClient::new(url.clone(), config.jwks_cache_ttl_secs).map(Arc::new))
59            .transpose()?;
60
61        Ok(Self {
62            jwt_secret: config.jwt_secret.clone(),
63            algorithm,
64            jwks_client,
65            issuer: config.jwt_issuer.clone(),
66            audience: config.jwt_audience.clone(),
67            skip_verification: false,
68        })
69    }
70
71    /// Create a new auth config with the given HMAC secret.
72    pub fn with_secret(secret: impl Into<String>) -> Self {
73        Self {
74            jwt_secret: Some(secret.into()),
75            ..Default::default()
76        }
77    }
78
79    /// Create a dev mode config that skips signature verification.
80    /// WARNING: Only use this for development and testing!
81    pub fn dev_mode() -> Self {
82        Self {
83            jwt_secret: None,
84            algorithm: JwtAlgorithm::HS256,
85            jwks_client: None,
86            issuer: None,
87            audience: None,
88            skip_verification: true,
89        }
90    }
91
92    /// Check if this config uses HMAC (symmetric) algorithms.
93    pub fn is_hmac(&self) -> bool {
94        matches!(
95            self.algorithm,
96            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
97        )
98    }
99
100    /// Check if this config uses RSA (asymmetric) algorithms.
101    pub fn is_rsa(&self) -> bool {
102        matches!(
103            self.algorithm,
104            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
105        )
106    }
107}
108
/// JWT signing algorithms the runtime can verify.
///
/// `HS*` variants are verified with a shared HMAC secret; `RS*` variants with
/// RSA public keys obtained from a JWKS endpoint. `HS256` is the default.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum JwtAlgorithm {
    /// HMAC using SHA-256.
    #[default]
    HS256,
    /// HMAC using SHA-384.
    HS384,
    /// HMAC using SHA-512.
    HS512,
    /// RSASSA-PKCS1-v1_5 using SHA-256.
    RS256,
    /// RSASSA-PKCS1-v1_5 using SHA-384.
    RS384,
    /// RSASSA-PKCS1-v1_5 using SHA-512.
    RS512,
}
120
121impl From<JwtAlgorithm> for Algorithm {
122    fn from(alg: JwtAlgorithm) -> Self {
123        match alg {
124            JwtAlgorithm::HS256 => Algorithm::HS256,
125            JwtAlgorithm::HS384 => Algorithm::HS384,
126            JwtAlgorithm::HS512 => Algorithm::HS512,
127            JwtAlgorithm::RS256 => Algorithm::RS256,
128            JwtAlgorithm::RS384 => Algorithm::RS384,
129            JwtAlgorithm::RS512 => Algorithm::RS512,
130        }
131    }
132}
133
134impl From<CoreJwtAlgorithm> for JwtAlgorithm {
135    fn from(alg: CoreJwtAlgorithm) -> Self {
136        match alg {
137            CoreJwtAlgorithm::HS256 => JwtAlgorithm::HS256,
138            CoreJwtAlgorithm::HS384 => JwtAlgorithm::HS384,
139            CoreJwtAlgorithm::HS512 => JwtAlgorithm::HS512,
140            CoreJwtAlgorithm::RS256 => JwtAlgorithm::RS256,
141            CoreJwtAlgorithm::RS384 => JwtAlgorithm::RS384,
142            CoreJwtAlgorithm::RS512 => JwtAlgorithm::RS512,
143        }
144    }
145}
146
147/// Authentication middleware.
148#[derive(Clone)]
149pub struct AuthMiddleware {
150    config: Arc<AuthConfig>,
151    /// Pre-computed HMAC decoding key (for performance).
152    hmac_key: Option<DecodingKey>,
153}
154
155impl std::fmt::Debug for AuthMiddleware {
156    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157        f.debug_struct("AuthMiddleware")
158            .field("config", &self.config)
159            .field("hmac_key", &self.hmac_key.is_some())
160            .finish()
161    }
162}
163
164impl AuthMiddleware {
165    /// Create a new auth middleware.
166    pub fn new(config: AuthConfig) -> Self {
167        if config.skip_verification {
168            tracing::warn!("JWT signature verification is DISABLED. Do not use in production.");
169        }
170
171        // Pre-compute HMAC key if using HMAC algorithm
172        let hmac_key = if config.skip_verification {
173            None
174        } else if config.is_hmac() {
175            config
176                .jwt_secret
177                .as_ref()
178                .filter(|s| !s.is_empty())
179                .map(|secret| DecodingKey::from_secret(secret.as_bytes()))
180        } else {
181            None
182        };
183
184        Self {
185            config: Arc::new(config),
186            hmac_key,
187        }
188    }
189
190    /// Create a middleware that allows all requests (development mode).
191    /// WARNING: This skips signature verification! Never use in production.
192    pub fn permissive() -> Self {
193        Self::new(AuthConfig::dev_mode())
194    }
195
196    /// Get the config.
197    pub fn config(&self) -> &AuthConfig {
198        &self.config
199    }
200
201    /// Validate a JWT token and extract claims.
202    pub async fn validate_token_async(&self, token: &str) -> Result<Claims, AuthError> {
203        if self.config.skip_verification {
204            return self.decode_without_verification(token);
205        }
206
207        if self.config.is_hmac() {
208            self.validate_hmac(token)
209        } else {
210            self.validate_rsa(token).await
211        }
212    }
213
214    /// Validate HMAC-signed token.
215    fn validate_hmac(&self, token: &str) -> Result<Claims, AuthError> {
216        let key = self.hmac_key.as_ref().ok_or_else(|| {
217            AuthError::InvalidToken("JWT secret not configured for HMAC".to_string())
218        })?;
219
220        self.decode_and_validate(token, key)
221    }
222
223    /// Validate RSA-signed token using JWKS.
224    async fn validate_rsa(&self, token: &str) -> Result<Claims, AuthError> {
225        let jwks = self.config.jwks_client.as_ref().ok_or_else(|| {
226            AuthError::InvalidToken("JWKS URL not configured for RSA".to_string())
227        })?;
228
229        // Extract key ID from token header
230        let header = jsonwebtoken::decode_header(token)
231            .map_err(|e| AuthError::InvalidToken(format!("Invalid token header: {}", e)))?;
232
233        debug!(kid = ?header.kid, alg = ?header.alg, "Validating RSA token");
234
235        // Get key from JWKS
236        let key = if let Some(kid) = header.kid {
237            jwks.get_key(&kid).await.map_err(|e| {
238                AuthError::InvalidToken(format!("Failed to get key '{}': {}", kid, e))
239            })?
240        } else {
241            jwks.get_any_key()
242                .await
243                .map_err(|e| AuthError::InvalidToken(format!("Failed to get JWKS key: {}", e)))?
244        };
245
246        self.decode_and_validate(token, &key)
247    }
248
249    /// Decode and validate token with the given key.
250    fn decode_and_validate(&self, token: &str, key: &DecodingKey) -> Result<Claims, AuthError> {
251        let mut validation = Validation::new(self.config.algorithm.into());
252
253        // Configure validation
254        validation.validate_exp = true;
255        validation.validate_nbf = false;
256        validation.leeway = 60; // 60 seconds clock skew tolerance
257
258        // Require exp and sub claims
259        validation.set_required_spec_claims(&["exp", "sub"]);
260
261        // Validate issuer if configured
262        if let Some(ref issuer) = self.config.issuer {
263            validation.set_issuer(&[issuer]);
264        }
265
266        // Validate audience if configured
267        if let Some(ref audience) = self.config.audience {
268            validation.set_audience(&[audience]);
269        } else {
270            validation.validate_aud = false;
271        }
272
273        let token_data =
274            decode::<Claims>(token, key, &validation).map_err(|e| self.map_jwt_error(e))?;
275
276        Ok(token_data.claims)
277    }
278
279    /// Map jsonwebtoken errors to AuthError.
280    fn map_jwt_error(&self, e: jsonwebtoken::errors::Error) -> AuthError {
281        match e.kind() {
282            jsonwebtoken::errors::ErrorKind::ExpiredSignature => AuthError::TokenExpired,
283            jsonwebtoken::errors::ErrorKind::InvalidSignature => {
284                AuthError::InvalidToken("Invalid signature".to_string())
285            }
286            jsonwebtoken::errors::ErrorKind::InvalidToken => {
287                AuthError::InvalidToken("Invalid token format".to_string())
288            }
289            jsonwebtoken::errors::ErrorKind::MissingRequiredClaim(claim) => {
290                AuthError::InvalidToken(format!("Missing required claim: {}", claim))
291            }
292            jsonwebtoken::errors::ErrorKind::InvalidIssuer => {
293                AuthError::InvalidToken("Invalid issuer".to_string())
294            }
295            jsonwebtoken::errors::ErrorKind::InvalidAudience => {
296                AuthError::InvalidToken("Invalid audience".to_string())
297            }
298            _ => AuthError::InvalidToken(e.to_string()),
299        }
300    }
301
302    /// Decode JWT token without signature verification (DEV MODE ONLY).
303    fn decode_without_verification(&self, token: &str) -> Result<Claims, AuthError> {
304        let token_data =
305            dangerous::insecure_decode::<Claims>(token).map_err(|e| match e.kind() {
306                jsonwebtoken::errors::ErrorKind::InvalidToken => {
307                    AuthError::InvalidToken("Invalid token format".to_string())
308                }
309                _ => AuthError::InvalidToken(e.to_string()),
310            })?;
311
312        // Still check expiration in dev mode
313        if token_data.claims.is_expired() {
314            return Err(AuthError::TokenExpired);
315        }
316
317        Ok(token_data.claims)
318    }
319}
320
321/// Authentication errors.
322#[derive(Debug, Clone, thiserror::Error)]
323pub enum AuthError {
324    #[error("Missing authorization header")]
325    MissingHeader,
326    #[error("Invalid authorization header format")]
327    InvalidHeader,
328    #[error("Invalid token: {0}")]
329    InvalidToken(String),
330    #[error("Token expired")]
331    TokenExpired,
332}
333
334/// Extract token from request headers.
335pub fn extract_token(req: &Request<Body>) -> Result<Option<String>, AuthError> {
336    let Some(header_value) = req.headers().get(axum::http::header::AUTHORIZATION) else {
337        return Ok(None);
338    };
339
340    let header = header_value
341        .to_str()
342        .map_err(|_| AuthError::InvalidHeader)?;
343    let token = header
344        .strip_prefix("Bearer ")
345        .ok_or(AuthError::InvalidHeader)?
346        .trim();
347
348    if token.is_empty() {
349        return Err(AuthError::InvalidHeader);
350    }
351
352    Ok(Some(token.to_string()))
353}
354
355/// Extract auth context from token (async, supports both HMAC and RSA/JWKS).
356pub async fn extract_auth_context_async(
357    token: Option<String>,
358    middleware: &AuthMiddleware,
359) -> Result<AuthContext, AuthError> {
360    match token {
361        Some(token) => middleware
362            .validate_token_async(&token)
363            .await
364            .map(build_auth_context_from_claims),
365        None => Ok(AuthContext::unauthenticated()),
366    }
367}
368
369/// Build auth context from validated claims.
370///
371/// This handles both UUID and non-UUID subjects properly:
372/// - UUID subjects: uses `authenticated()` with the parsed UUID
373/// - Non-UUID subjects: uses `authenticated_without_uuid()` and stores raw subject in claims
374pub fn build_auth_context_from_claims(claims: Claims) -> AuthContext {
375    // Try to parse subject as UUID first (before moving claims)
376    let user_id = claims.user_id();
377
378    // Build custom claims with raw subject included
379    let mut custom_claims = claims.custom;
380    custom_claims.insert("sub".to_string(), serde_json::Value::String(claims.sub));
381
382    match user_id {
383        Some(uuid) => {
384            // Subject is a valid UUID
385            AuthContext::authenticated(uuid, claims.roles, custom_claims)
386        }
387        None => {
388            // Subject is not a UUID (e.g., Firebase uid, Clerk user_xxx, email)
389            // Still authenticated, but user_id() will return None
390            AuthContext::authenticated_without_uuid(claims.roles, custom_claims)
391        }
392    }
393}
394
395/// Authentication middleware function.
396pub async fn auth_middleware(
397    State(middleware): State<Arc<AuthMiddleware>>,
398    req: Request<Body>,
399    next: Next,
400) -> Response {
401    let token = match extract_token(&req) {
402        Ok(token) => token,
403        Err(e) => {
404            tracing::warn!(error = %e, "Invalid authorization header");
405            return (StatusCode::UNAUTHORIZED, "Invalid authorization header").into_response();
406        }
407    };
408    tracing::trace!(
409        token_present = token.is_some(),
410        "Auth middleware processing request"
411    );
412
413    let auth_context = match extract_auth_context_async(token, &middleware).await {
414        Ok(auth_context) => auth_context,
415        Err(e) => {
416            tracing::warn!(error = %e, "Token validation failed");
417            return (StatusCode::UNAUTHORIZED, "Invalid authentication token").into_response();
418        }
419    };
420    tracing::trace!(
421        authenticated = auth_context.is_authenticated(),
422        "Auth context created"
423    );
424
425    let mut req = req;
426    req.extensions_mut().insert(auth_context);
427
428    next.run(req).await
429}
430
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header, encode};

    /// Claims for subject `test-user-id` with the `user` role, valid for one
    /// hour, or already expired by one hour when `expired` is true.
    fn create_test_claims(expired: bool) -> Claims {
        use forge_core::auth::ClaimsBuilder;

        // A negative lifetime puts the expiry in the past.
        let lifetime_secs = if expired { -3600 } else { 3600 };
        ClaimsBuilder::new()
            .subject("test-user-id")
            .role("user")
            .duration_secs(lifetime_secs)
            .build()
            .unwrap()
    }

    /// Sign `claims` with `secret` using the default (HS256) header.
    fn create_test_token(claims: &Claims, secret: &str) -> String {
        let signing_key = EncodingKey::from_secret(secret.as_bytes());
        encode(&Header::default(), claims, &signing_key).unwrap()
    }

    /// Middleware that verifies HMAC tokens against `secret`.
    fn hmac_middleware(secret: &str) -> AuthMiddleware {
        AuthMiddleware::new(AuthConfig::with_secret(secret))
    }

    /// Middleware in dev mode (signature verification skipped).
    fn dev_middleware() -> AuthMiddleware {
        AuthMiddleware::new(AuthConfig::dev_mode())
    }

    #[test]
    fn test_auth_config_default() {
        let AuthConfig {
            algorithm,
            skip_verification,
            ..
        } = AuthConfig::default();
        assert_eq!(algorithm, JwtAlgorithm::HS256);
        assert!(!skip_verification);
    }

    #[test]
    fn test_auth_config_dev_mode() {
        assert!(AuthConfig::dev_mode().skip_verification);
    }

    #[test]
    fn test_auth_middleware_permissive() {
        let permissive = AuthMiddleware::permissive();
        assert!(permissive.config().skip_verification);
    }

    #[tokio::test]
    async fn test_valid_token_with_correct_secret() {
        let secret = "test-secret-key";
        let middleware = hmac_middleware(secret);
        let token = create_test_token(&create_test_claims(false), secret);

        let validated = middleware.validate_token_async(&token).await.unwrap();
        assert_eq!(validated.sub, "test-user-id");
    }

    #[tokio::test]
    async fn test_valid_token_with_wrong_secret() {
        let middleware = hmac_middleware("correct-secret");
        let token = create_test_token(&create_test_claims(false), "wrong-secret");

        let outcome = middleware.validate_token_async(&token).await;
        assert!(
            matches!(outcome, Err(AuthError::InvalidToken(_))),
            "Expected InvalidToken error"
        );
    }

    #[tokio::test]
    async fn test_expired_token() {
        let secret = "test-secret";
        let middleware = hmac_middleware(secret);
        let token = create_test_token(&create_test_claims(true), secret);

        let outcome = middleware.validate_token_async(&token).await;
        assert!(
            matches!(outcome, Err(AuthError::TokenExpired)),
            "Expected TokenExpired error"
        );
    }

    #[tokio::test]
    async fn test_tampered_token() {
        let secret = "test-secret";
        let middleware = hmac_middleware(secret);
        let mut token = create_test_token(&create_test_claims(false), secret);

        // Flip the final signature character so the MAC no longer matches.
        if let Some(last) = token.pop() {
            token.push(if last == 'a' { 'b' } else { 'a' });
        }

        assert!(middleware.validate_token_async(&token).await.is_err());
    }

    #[tokio::test]
    async fn test_dev_mode_skips_signature() {
        let middleware = dev_middleware();
        // Signed with a secret the middleware knows nothing about.
        let token = create_test_token(&create_test_claims(false), "any-secret");

        assert!(middleware.validate_token_async(&token).await.is_ok());
    }

    #[tokio::test]
    async fn test_dev_mode_still_checks_expiration() {
        let middleware = dev_middleware();
        let token = create_test_token(&create_test_claims(true), "any-secret");

        let outcome = middleware.validate_token_async(&token).await;
        assert!(
            matches!(outcome, Err(AuthError::TokenExpired)),
            "Expected TokenExpired error even in dev mode"
        );
    }

    #[tokio::test]
    async fn test_invalid_token_format() {
        let middleware = hmac_middleware("secret");

        let outcome = middleware.validate_token_async("not-a-valid-jwt").await;
        assert!(
            matches!(outcome, Err(AuthError::InvalidToken(_))),
            "Expected InvalidToken error"
        );
    }

    #[test]
    fn test_algorithm_conversion() {
        let expected = [
            // HMAC algorithms
            (JwtAlgorithm::HS256, Algorithm::HS256),
            (JwtAlgorithm::HS384, Algorithm::HS384),
            (JwtAlgorithm::HS512, Algorithm::HS512),
            // RSA algorithms
            (JwtAlgorithm::RS256, Algorithm::RS256),
            (JwtAlgorithm::RS384, Algorithm::RS384),
            (JwtAlgorithm::RS512, Algorithm::RS512),
        ];
        for (ours, theirs) in expected {
            assert_eq!(Algorithm::from(ours), theirs);
        }
    }

    #[test]
    fn test_is_hmac_and_is_rsa() {
        let hmac = AuthConfig::with_secret("test");
        assert!(hmac.is_hmac());
        assert!(!hmac.is_rsa());

        let rsa = AuthConfig {
            algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        assert!(!rsa.is_hmac());
        assert!(rsa.is_rsa());
    }

    #[test]
    fn test_extract_token_rejects_non_bearer_header() {
        let req = Request::builder()
            .header(axum::http::header::AUTHORIZATION, "Basic abc")
            .body(Body::empty())
            .unwrap();

        assert!(matches!(extract_token(&req), Err(AuthError::InvalidHeader)));
    }

    #[tokio::test]
    async fn test_extract_auth_context_async_invalid_token_errors() {
        let middleware = hmac_middleware("secret");

        let outcome =
            extract_auth_context_async(Some(String::from("bad.token")), &middleware).await;
        assert!(matches!(outcome, Err(AuthError::InvalidToken(_))));
    }
}