Skip to main content

forge_runtime/gateway/
auth.rs

1use std::sync::Arc;
2
3use axum::{
4    body::Body,
5    extract::{Request, State},
6    http::StatusCode,
7    middleware::Next,
8    response::{IntoResponse, Response},
9};
10use forge_core::auth::Claims;
11use forge_core::config::JwtAlgorithm as CoreJwtAlgorithm;
12use forge_core::function::AuthContext;
13use jsonwebtoken::{Algorithm, DecodingKey, Validation, dangerous, decode};
14use tracing::debug;
15
16use super::jwks::JwksClient;
17
18/// Authentication configuration for the runtime.
19#[derive(Debug, Clone)]
20pub struct AuthConfig {
21    /// JWT secret for HMAC algorithms (HS256, HS384, HS512).
22    pub jwt_secret: Option<String>,
23    /// JWT algorithm.
24    pub algorithm: JwtAlgorithm,
25    /// JWKS client for RSA algorithms.
26    pub jwks_client: Option<Arc<JwksClient>>,
27    /// Expected token issuer (iss claim).
28    pub issuer: Option<String>,
29    /// Expected audience (aud claim).
30    pub audience: Option<String>,
31    /// Skip signature verification (DEV MODE ONLY - NEVER USE IN PRODUCTION).
32    pub skip_verification: bool,
33}
34
35impl Default for AuthConfig {
36    fn default() -> Self {
37        Self {
38            jwt_secret: None,
39            algorithm: JwtAlgorithm::HS256,
40            jwks_client: None,
41            issuer: None,
42            audience: None,
43            skip_verification: false,
44        }
45    }
46}
47
48impl AuthConfig {
49    /// Create auth config from forge core config.
50    pub fn from_forge_config(config: &forge_core::config::AuthConfig) -> Self {
51        let algorithm = JwtAlgorithm::from(config.jwt_algorithm);
52
53        let jwks_client = config
54            .jwks_url
55            .as_ref()
56            .map(|url| Arc::new(JwksClient::new(url.clone(), config.jwks_cache_ttl_secs)));
57
58        Self {
59            jwt_secret: config.jwt_secret.clone(),
60            algorithm,
61            jwks_client,
62            issuer: config.jwt_issuer.clone(),
63            audience: config.jwt_audience.clone(),
64            skip_verification: false,
65        }
66    }
67
68    /// Create a new auth config with the given HMAC secret.
69    pub fn with_secret(secret: impl Into<String>) -> Self {
70        Self {
71            jwt_secret: Some(secret.into()),
72            ..Default::default()
73        }
74    }
75
76    /// Create a dev mode config that skips signature verification.
77    /// WARNING: Only use this for development and testing!
78    pub fn dev_mode() -> Self {
79        Self {
80            jwt_secret: None,
81            algorithm: JwtAlgorithm::HS256,
82            jwks_client: None,
83            issuer: None,
84            audience: None,
85            skip_verification: true,
86        }
87    }
88
89    /// Check if this config uses HMAC (symmetric) algorithms.
90    pub fn is_hmac(&self) -> bool {
91        matches!(
92            self.algorithm,
93            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
94        )
95    }
96
97    /// Check if this config uses RSA (asymmetric) algorithms.
98    pub fn is_rsa(&self) -> bool {
99        matches!(
100            self.algorithm,
101            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
102        )
103    }
104}
105
/// JWT signing algorithms the gateway can be configured with.
///
/// The `HS*` variants are HMAC (shared secret) algorithms and the `RS*`
/// variants are RSA (public key) algorithms. `HS256` is the default.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum JwtAlgorithm {
    /// HMAC, 256-bit variant (the default).
    #[default]
    HS256,
    /// HMAC, 384-bit variant.
    HS384,
    /// HMAC, 512-bit variant.
    HS512,
    /// RSA, 256-bit variant.
    RS256,
    /// RSA, 384-bit variant.
    RS384,
    /// RSA, 512-bit variant.
    RS512,
}
117
118impl From<JwtAlgorithm> for Algorithm {
119    fn from(alg: JwtAlgorithm) -> Self {
120        match alg {
121            JwtAlgorithm::HS256 => Algorithm::HS256,
122            JwtAlgorithm::HS384 => Algorithm::HS384,
123            JwtAlgorithm::HS512 => Algorithm::HS512,
124            JwtAlgorithm::RS256 => Algorithm::RS256,
125            JwtAlgorithm::RS384 => Algorithm::RS384,
126            JwtAlgorithm::RS512 => Algorithm::RS512,
127        }
128    }
129}
130
131impl From<CoreJwtAlgorithm> for JwtAlgorithm {
132    fn from(alg: CoreJwtAlgorithm) -> Self {
133        match alg {
134            CoreJwtAlgorithm::HS256 => JwtAlgorithm::HS256,
135            CoreJwtAlgorithm::HS384 => JwtAlgorithm::HS384,
136            CoreJwtAlgorithm::HS512 => JwtAlgorithm::HS512,
137            CoreJwtAlgorithm::RS256 => JwtAlgorithm::RS256,
138            CoreJwtAlgorithm::RS384 => JwtAlgorithm::RS384,
139            CoreJwtAlgorithm::RS512 => JwtAlgorithm::RS512,
140        }
141    }
142}
143
144/// Authentication middleware.
145#[derive(Clone)]
146pub struct AuthMiddleware {
147    config: Arc<AuthConfig>,
148    /// Pre-computed HMAC decoding key (for performance).
149    hmac_key: Option<DecodingKey>,
150}
151
152impl std::fmt::Debug for AuthMiddleware {
153    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154        f.debug_struct("AuthMiddleware")
155            .field("config", &self.config)
156            .field("hmac_key", &self.hmac_key.is_some())
157            .finish()
158    }
159}
160
161impl AuthMiddleware {
162    /// Create a new auth middleware.
163    pub fn new(config: AuthConfig) -> Self {
164        // Pre-compute HMAC key if using HMAC algorithm
165        let hmac_key = if config.skip_verification {
166            None
167        } else if config.is_hmac() {
168            config
169                .jwt_secret
170                .as_ref()
171                .filter(|s| !s.is_empty())
172                .map(|secret| DecodingKey::from_secret(secret.as_bytes()))
173        } else {
174            None
175        };
176
177        Self {
178            config: Arc::new(config),
179            hmac_key,
180        }
181    }
182
183    /// Create a middleware that allows all requests (development mode).
184    /// WARNING: This skips signature verification! Never use in production.
185    pub fn permissive() -> Self {
186        Self::new(AuthConfig::dev_mode())
187    }
188
189    /// Get the config.
190    pub fn config(&self) -> &AuthConfig {
191        &self.config
192    }
193
194    /// Validate a JWT token and extract claims.
195    pub async fn validate_token_async(&self, token: &str) -> Result<Claims, AuthError> {
196        if self.config.skip_verification {
197            return self.decode_without_verification(token);
198        }
199
200        if self.config.is_hmac() {
201            self.validate_hmac(token)
202        } else {
203            self.validate_rsa(token).await
204        }
205    }
206
207    /// Validate HMAC-signed token.
208    fn validate_hmac(&self, token: &str) -> Result<Claims, AuthError> {
209        let key = self.hmac_key.as_ref().ok_or_else(|| {
210            AuthError::InvalidToken("JWT secret not configured for HMAC".to_string())
211        })?;
212
213        self.decode_and_validate(token, key)
214    }
215
216    /// Validate RSA-signed token using JWKS.
217    async fn validate_rsa(&self, token: &str) -> Result<Claims, AuthError> {
218        let jwks = self.config.jwks_client.as_ref().ok_or_else(|| {
219            AuthError::InvalidToken("JWKS URL not configured for RSA".to_string())
220        })?;
221
222        // Extract key ID from token header
223        let header = jsonwebtoken::decode_header(token)
224            .map_err(|e| AuthError::InvalidToken(format!("Invalid token header: {}", e)))?;
225
226        debug!(kid = ?header.kid, alg = ?header.alg, "Validating RSA token");
227
228        // Get key from JWKS
229        let key = if let Some(kid) = header.kid {
230            jwks.get_key(&kid).await.map_err(|e| {
231                AuthError::InvalidToken(format!("Failed to get key '{}': {}", kid, e))
232            })?
233        } else {
234            jwks.get_any_key()
235                .await
236                .map_err(|e| AuthError::InvalidToken(format!("Failed to get JWKS key: {}", e)))?
237        };
238
239        self.decode_and_validate(token, &key)
240    }
241
242    /// Decode and validate token with the given key.
243    fn decode_and_validate(&self, token: &str, key: &DecodingKey) -> Result<Claims, AuthError> {
244        let mut validation = Validation::new(self.config.algorithm.into());
245
246        // Configure validation
247        validation.validate_exp = true;
248        validation.validate_nbf = false;
249        validation.leeway = 60; // 60 seconds clock skew tolerance
250
251        // Require exp and sub claims
252        validation.set_required_spec_claims(&["exp", "sub"]);
253
254        // Validate issuer if configured
255        if let Some(ref issuer) = self.config.issuer {
256            validation.set_issuer(&[issuer]);
257        }
258
259        // Validate audience if configured
260        if let Some(ref audience) = self.config.audience {
261            validation.set_audience(&[audience]);
262        } else {
263            validation.validate_aud = false;
264        }
265
266        let token_data =
267            decode::<Claims>(token, key, &validation).map_err(|e| self.map_jwt_error(e))?;
268
269        Ok(token_data.claims)
270    }
271
272    /// Map jsonwebtoken errors to AuthError.
273    fn map_jwt_error(&self, e: jsonwebtoken::errors::Error) -> AuthError {
274        match e.kind() {
275            jsonwebtoken::errors::ErrorKind::ExpiredSignature => AuthError::TokenExpired,
276            jsonwebtoken::errors::ErrorKind::InvalidSignature => {
277                AuthError::InvalidToken("Invalid signature".to_string())
278            }
279            jsonwebtoken::errors::ErrorKind::InvalidToken => {
280                AuthError::InvalidToken("Invalid token format".to_string())
281            }
282            jsonwebtoken::errors::ErrorKind::MissingRequiredClaim(claim) => {
283                AuthError::InvalidToken(format!("Missing required claim: {}", claim))
284            }
285            jsonwebtoken::errors::ErrorKind::InvalidIssuer => {
286                AuthError::InvalidToken("Invalid issuer".to_string())
287            }
288            jsonwebtoken::errors::ErrorKind::InvalidAudience => {
289                AuthError::InvalidToken("Invalid audience".to_string())
290            }
291            _ => AuthError::InvalidToken(e.to_string()),
292        }
293    }
294
295    /// Decode JWT token without signature verification (DEV MODE ONLY).
296    fn decode_without_verification(&self, token: &str) -> Result<Claims, AuthError> {
297        let token_data =
298            dangerous::insecure_decode::<Claims>(token).map_err(|e| match e.kind() {
299                jsonwebtoken::errors::ErrorKind::InvalidToken => {
300                    AuthError::InvalidToken("Invalid token format".to_string())
301                }
302                _ => AuthError::InvalidToken(e.to_string()),
303            })?;
304
305        // Still check expiration in dev mode
306        if token_data.claims.is_expired() {
307            return Err(AuthError::TokenExpired);
308        }
309
310        Ok(token_data.claims)
311    }
312}
313
314/// Authentication errors.
315#[derive(Debug, Clone, thiserror::Error)]
316pub enum AuthError {
317    #[error("Missing authorization header")]
318    MissingHeader,
319    #[error("Invalid authorization header format")]
320    InvalidHeader,
321    #[error("Invalid token: {0}")]
322    InvalidToken(String),
323    #[error("Token expired")]
324    TokenExpired,
325}
326
327/// Extract token from request headers.
328pub fn extract_token(req: &Request<Body>) -> Result<Option<String>, AuthError> {
329    let Some(header_value) = req.headers().get(axum::http::header::AUTHORIZATION) else {
330        return Ok(None);
331    };
332
333    let header = header_value
334        .to_str()
335        .map_err(|_| AuthError::InvalidHeader)?;
336    let token = header
337        .strip_prefix("Bearer ")
338        .ok_or(AuthError::InvalidHeader)?
339        .trim();
340
341    if token.is_empty() {
342        return Err(AuthError::InvalidHeader);
343    }
344
345    Ok(Some(token.to_string()))
346}
347
348/// Extract auth context from token (async, supports both HMAC and RSA/JWKS).
349pub async fn extract_auth_context_async(
350    token: Option<String>,
351    middleware: &AuthMiddleware,
352) -> Result<AuthContext, AuthError> {
353    match token {
354        Some(token) => middleware
355            .validate_token_async(&token)
356            .await
357            .map(build_auth_context_from_claims),
358        None => Ok(AuthContext::unauthenticated()),
359    }
360}
361
362/// Build auth context from validated claims.
363///
364/// This handles both UUID and non-UUID subjects properly:
365/// - UUID subjects: uses `authenticated()` with the parsed UUID
366/// - Non-UUID subjects: uses `authenticated_without_uuid()` and stores raw subject in claims
367pub fn build_auth_context_from_claims(claims: Claims) -> AuthContext {
368    // Try to parse subject as UUID first (before moving claims)
369    let user_id = claims.user_id();
370
371    // Build custom claims with raw subject included
372    let mut custom_claims = claims.custom;
373    custom_claims.insert("sub".to_string(), serde_json::Value::String(claims.sub));
374
375    match user_id {
376        Some(uuid) => {
377            // Subject is a valid UUID
378            AuthContext::authenticated(uuid, claims.roles, custom_claims)
379        }
380        None => {
381            // Subject is not a UUID (e.g., Firebase uid, Clerk user_xxx, email)
382            // Still authenticated, but user_id() will return None
383            AuthContext::authenticated_without_uuid(claims.roles, custom_claims)
384        }
385    }
386}
387
388/// Authentication middleware function.
389pub async fn auth_middleware(
390    State(middleware): State<Arc<AuthMiddleware>>,
391    req: Request<Body>,
392    next: Next,
393) -> Response {
394    let token = match extract_token(&req) {
395        Ok(token) => token,
396        Err(e) => {
397            tracing::warn!(error = %e, "Invalid authorization header");
398            return (StatusCode::UNAUTHORIZED, "Invalid authorization header").into_response();
399        }
400    };
401    tracing::trace!(
402        token_present = token.is_some(),
403        "Auth middleware processing request"
404    );
405
406    let auth_context = match extract_auth_context_async(token, &middleware).await {
407        Ok(auth_context) => auth_context,
408        Err(e) => {
409            tracing::warn!(error = %e, "Token validation failed");
410            return (StatusCode::UNAUTHORIZED, "Invalid authentication token").into_response();
411        }
412    };
413    tracing::trace!(
414        authenticated = auth_context.is_authenticated(),
415        "Auth context created"
416    );
417
418    let mut req = req;
419    req.extensions_mut().insert(auth_context);
420
421    next.run(req).await
422}
423
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header, encode};

    /// Claims for subject `test-user-id` with role `user`: expired one hour
    /// ago when `expired` is set, otherwise valid for one hour.
    fn create_test_claims(expired: bool) -> Claims {
        use forge_core::auth::ClaimsBuilder;

        let lifetime_secs = if expired { -3600 } else { 3600 };

        ClaimsBuilder::new()
            .subject("test-user-id")
            .role("user")
            .duration_secs(lifetime_secs)
            .build()
            .unwrap()
    }

    /// Sign `claims` with `secret` using the default JWT header.
    fn create_test_token(claims: &Claims, secret: &str) -> String {
        let signing_key = EncodingKey::from_secret(secret.as_bytes());
        encode(&Header::default(), claims, &signing_key).unwrap()
    }

    /// Middleware that verifies tokens against `secret`.
    fn hmac_middleware(secret: &str) -> AuthMiddleware {
        AuthMiddleware::new(AuthConfig::with_secret(secret))
    }

    #[test]
    fn test_auth_config_default() {
        let config = AuthConfig::default();
        assert!(!config.skip_verification);
        assert_eq!(config.algorithm, JwtAlgorithm::HS256);
    }

    #[test]
    fn test_auth_config_dev_mode() {
        assert!(AuthConfig::dev_mode().skip_verification);
    }

    #[test]
    fn test_auth_middleware_permissive() {
        let middleware = AuthMiddleware::permissive();
        assert!(middleware.config().skip_verification);
    }

    #[tokio::test]
    async fn test_valid_token_with_correct_secret() {
        let secret = "test-secret-key";
        let middleware = hmac_middleware(secret);
        let token = create_test_token(&create_test_claims(false), secret);

        let validated_claims = middleware.validate_token_async(&token).await.unwrap();
        assert_eq!(validated_claims.sub, "test-user-id");
    }

    #[tokio::test]
    async fn test_valid_token_with_wrong_secret() {
        let middleware = hmac_middleware("correct-secret");
        let token = create_test_token(&create_test_claims(false), "wrong-secret");

        let result = middleware.validate_token_async(&token).await;
        assert!(
            matches!(result, Err(AuthError::InvalidToken(_))),
            "Expected InvalidToken error"
        );
    }

    #[tokio::test]
    async fn test_expired_token() {
        let secret = "test-secret";
        let middleware = hmac_middleware(secret);
        let token = create_test_token(&create_test_claims(true), secret);

        let result = middleware.validate_token_async(&token).await;
        assert!(
            matches!(result, Err(AuthError::TokenExpired)),
            "Expected TokenExpired error"
        );
    }

    #[tokio::test]
    async fn test_tampered_token() {
        let secret = "test-secret";
        let middleware = hmac_middleware(secret);
        let mut token = create_test_token(&create_test_claims(false), secret);

        // Swap the final signature character for a different one.
        match token.pop() {
            Some('a') => token.push('b'),
            Some(_) => token.push('a'),
            None => {}
        }

        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_dev_mode_skips_signature() {
        let middleware = AuthMiddleware::new(AuthConfig::dev_mode());

        // The signing secret is irrelevant because the signature is not checked.
        let token = create_test_token(&create_test_claims(false), "any-secret");

        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_dev_mode_still_checks_expiration() {
        let middleware = AuthMiddleware::new(AuthConfig::dev_mode());
        let token = create_test_token(&create_test_claims(true), "any-secret");

        let result = middleware.validate_token_async(&token).await;
        assert!(
            matches!(result, Err(AuthError::TokenExpired)),
            "Expected TokenExpired error even in dev mode"
        );
    }

    #[tokio::test]
    async fn test_invalid_token_format() {
        let middleware = hmac_middleware("secret");

        let result = middleware.validate_token_async("not-a-valid-jwt").await;
        assert!(
            matches!(result, Err(AuthError::InvalidToken(_))),
            "Expected InvalidToken error"
        );
    }

    #[test]
    fn test_algorithm_conversion() {
        let expected_pairs = [
            // HMAC algorithms
            (JwtAlgorithm::HS256, Algorithm::HS256),
            (JwtAlgorithm::HS384, Algorithm::HS384),
            (JwtAlgorithm::HS512, Algorithm::HS512),
            // RSA algorithms
            (JwtAlgorithm::RS256, Algorithm::RS256),
            (JwtAlgorithm::RS384, Algorithm::RS384),
            (JwtAlgorithm::RS512, Algorithm::RS512),
        ];

        for (ours, theirs) in expected_pairs {
            assert_eq!(Algorithm::from(ours), theirs);
        }
    }

    #[test]
    fn test_is_hmac_and_is_rsa() {
        let hmac_config = AuthConfig::with_secret("test");
        assert!(hmac_config.is_hmac());
        assert!(!hmac_config.is_rsa());

        let rsa_config = AuthConfig {
            algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        assert!(rsa_config.is_rsa());
        assert!(!rsa_config.is_hmac());
    }

    #[test]
    fn test_extract_token_rejects_non_bearer_header() {
        let req = Request::builder()
            .header(axum::http::header::AUTHORIZATION, "Basic abc")
            .body(Body::empty())
            .unwrap();

        assert!(matches!(extract_token(&req), Err(AuthError::InvalidHeader)));
    }

    #[tokio::test]
    async fn test_extract_auth_context_async_invalid_token_errors() {
        let middleware = hmac_middleware("secret");
        let bad_token = Some("bad.token".to_string());

        let result = extract_auth_context_async(bad_token, &middleware).await;
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }
}