Skip to main content

forge_runtime/gateway/
auth.rs

1use std::sync::Arc;
2
3use axum::{
4    body::Body,
5    extract::{Request, State},
6    http::StatusCode,
7    middleware::Next,
8    response::{IntoResponse, Json, Response},
9};
10use forge_core::auth::Claims;
11use forge_core::config::JwtAlgorithm as CoreJwtAlgorithm;
12use forge_core::function::AuthContext;
13use jsonwebtoken::{Algorithm, DecodingKey, Validation, dangerous, decode, encode};
14use tracing::debug;
15
16use super::jwks::JwksClient;
17
18/// Authentication configuration for the runtime.
19#[derive(Debug, Clone)]
20pub struct AuthConfig {
21    /// JWT secret for HMAC algorithms (HS256, HS384, HS512).
22    pub jwt_secret: Option<String>,
23    /// JWT algorithm.
24    pub algorithm: JwtAlgorithm,
25    /// JWKS client for RSA algorithms.
26    pub jwks_client: Option<Arc<JwksClient>>,
27    /// Expected token issuer (iss claim).
28    pub issuer: Option<String>,
29    /// Expected audience (aud claim).
30    pub audience: Option<String>,
31    /// Skip signature verification (DEV MODE ONLY - NEVER USE IN PRODUCTION).
32    pub skip_verification: bool,
33}
34
35impl Default for AuthConfig {
36    fn default() -> Self {
37        Self {
38            jwt_secret: None,
39            algorithm: JwtAlgorithm::HS256,
40            jwks_client: None,
41            issuer: None,
42            audience: None,
43            skip_verification: false,
44        }
45    }
46}
47
48impl AuthConfig {
49    /// Create auth config from forge core config.
50    pub fn from_forge_config(
51        config: &forge_core::config::AuthConfig,
52    ) -> Result<Self, super::jwks::JwksError> {
53        let algorithm = JwtAlgorithm::from(config.jwt_algorithm);
54
55        let jwks_client = config
56            .jwks_url
57            .as_ref()
58            .map(|url| JwksClient::new(url.clone(), config.jwks_cache_ttl_secs).map(Arc::new))
59            .transpose()?;
60
61        Ok(Self {
62            jwt_secret: config.jwt_secret.clone(),
63            algorithm,
64            jwks_client,
65            issuer: config.jwt_issuer.clone(),
66            audience: config.jwt_audience.clone(),
67            skip_verification: false,
68        })
69    }
70
71    /// Create a new auth config with the given HMAC secret.
72    pub fn with_secret(secret: impl Into<String>) -> Self {
73        Self {
74            jwt_secret: Some(secret.into()),
75            ..Default::default()
76        }
77    }
78
79    /// Create a dev mode config that skips signature verification.
80    /// WARNING: Only use this for development and testing!
81    pub fn dev_mode() -> Self {
82        Self {
83            jwt_secret: None,
84            algorithm: JwtAlgorithm::HS256,
85            jwks_client: None,
86            issuer: None,
87            audience: None,
88            skip_verification: true,
89        }
90    }
91
92    /// Check if this config uses HMAC (symmetric) algorithms.
93    pub fn is_hmac(&self) -> bool {
94        matches!(
95            self.algorithm,
96            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
97        )
98    }
99
100    /// Check if this config uses RSA (asymmetric) algorithms.
101    pub fn is_rsa(&self) -> bool {
102        matches!(
103            self.algorithm,
104            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
105        )
106    }
107}
108
/// JWT signing algorithms understood by the gateway.
///
/// The `HS*` variants use a shared HMAC secret; the `RS*` variants are RSA and
/// are verified with keys from the configured JWKS client. Defaults to `HS256`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum JwtAlgorithm {
    /// HMAC using SHA-256.
    #[default]
    HS256,
    /// HMAC using SHA-384.
    HS384,
    /// HMAC using SHA-512.
    HS512,
    /// RSASSA-PKCS1-v1_5 using SHA-256.
    RS256,
    /// RSASSA-PKCS1-v1_5 using SHA-384.
    RS384,
    /// RSASSA-PKCS1-v1_5 using SHA-512.
    RS512,
}
120
121impl From<JwtAlgorithm> for Algorithm {
122    fn from(alg: JwtAlgorithm) -> Self {
123        match alg {
124            JwtAlgorithm::HS256 => Algorithm::HS256,
125            JwtAlgorithm::HS384 => Algorithm::HS384,
126            JwtAlgorithm::HS512 => Algorithm::HS512,
127            JwtAlgorithm::RS256 => Algorithm::RS256,
128            JwtAlgorithm::RS384 => Algorithm::RS384,
129            JwtAlgorithm::RS512 => Algorithm::RS512,
130        }
131    }
132}
133
134impl From<CoreJwtAlgorithm> for JwtAlgorithm {
135    fn from(alg: CoreJwtAlgorithm) -> Self {
136        match alg {
137            CoreJwtAlgorithm::HS256 => JwtAlgorithm::HS256,
138            CoreJwtAlgorithm::HS384 => JwtAlgorithm::HS384,
139            CoreJwtAlgorithm::HS512 => JwtAlgorithm::HS512,
140            CoreJwtAlgorithm::RS256 => JwtAlgorithm::RS256,
141            CoreJwtAlgorithm::RS384 => JwtAlgorithm::RS384,
142            CoreJwtAlgorithm::RS512 => JwtAlgorithm::RS512,
143        }
144    }
145}
146
147/// Token issuer for HMAC-based JWT signing.
148///
149/// Created from the auth config when an HMAC algorithm is configured.
150/// Passed into MutationContext so handlers can call `ctx.issue_token()`.
151#[derive(Clone)]
152pub struct HmacTokenIssuer {
153    secret: String,
154    algorithm: Algorithm,
155}
156
157impl HmacTokenIssuer {
158    /// Create a token issuer from auth config, if HMAC auth is configured.
159    pub fn from_config(config: &AuthConfig) -> Option<Self> {
160        if !config.is_hmac() {
161            return None;
162        }
163        let secret = config.jwt_secret.as_ref()?.clone();
164        if secret.is_empty() {
165            return None;
166        }
167        Some(Self {
168            secret,
169            algorithm: config.algorithm.into(),
170        })
171    }
172}
173
174impl forge_core::TokenIssuer for HmacTokenIssuer {
175    fn sign(&self, claims: &Claims) -> forge_core::Result<String> {
176        let header = jsonwebtoken::Header::new(self.algorithm);
177        encode(
178            &header,
179            claims,
180            &jsonwebtoken::EncodingKey::from_secret(self.secret.as_bytes()),
181        )
182        .map_err(|e| forge_core::ForgeError::Internal(format!("token signing error: {e}")))
183    }
184}
185
186/// Authentication middleware.
187#[derive(Clone)]
188pub struct AuthMiddleware {
189    config: Arc<AuthConfig>,
190    /// Pre-computed HMAC decoding key (for performance).
191    hmac_key: Option<DecodingKey>,
192}
193
194impl std::fmt::Debug for AuthMiddleware {
195    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196        f.debug_struct("AuthMiddleware")
197            .field("config", &self.config)
198            .field("hmac_key", &self.hmac_key.is_some())
199            .finish()
200    }
201}
202
203impl AuthMiddleware {
204    /// Create a new auth middleware.
205    pub fn new(config: AuthConfig) -> Self {
206        if config.skip_verification {
207            tracing::warn!("JWT signature verification is DISABLED. Do not use in production.");
208        }
209
210        // Pre-compute HMAC key if using HMAC algorithm
211        let hmac_key = if config.skip_verification {
212            None
213        } else if config.is_hmac() {
214            config
215                .jwt_secret
216                .as_ref()
217                .filter(|s| !s.is_empty())
218                .map(|secret| DecodingKey::from_secret(secret.as_bytes()))
219        } else {
220            None
221        };
222
223        Self {
224            config: Arc::new(config),
225            hmac_key,
226        }
227    }
228
229    /// Create a middleware that allows all requests (development mode).
230    /// WARNING: This skips signature verification! Never use in production.
231    pub fn permissive() -> Self {
232        Self::new(AuthConfig::dev_mode())
233    }
234
235    /// Get the config.
236    pub fn config(&self) -> &AuthConfig {
237        &self.config
238    }
239
240    /// Validate a JWT token and extract claims.
241    pub async fn validate_token_async(&self, token: &str) -> Result<Claims, AuthError> {
242        if self.config.skip_verification {
243            return self.decode_without_verification(token);
244        }
245
246        if self.config.is_hmac() {
247            self.validate_hmac(token)
248        } else {
249            self.validate_rsa(token).await
250        }
251    }
252
253    /// Validate HMAC-signed token.
254    fn validate_hmac(&self, token: &str) -> Result<Claims, AuthError> {
255        let key = self.hmac_key.as_ref().ok_or_else(|| {
256            AuthError::InvalidToken("JWT secret not configured for HMAC".to_string())
257        })?;
258
259        self.decode_and_validate(token, key)
260    }
261
262    /// Validate RSA-signed token using JWKS.
263    async fn validate_rsa(&self, token: &str) -> Result<Claims, AuthError> {
264        let jwks = self.config.jwks_client.as_ref().ok_or_else(|| {
265            AuthError::InvalidToken("JWKS URL not configured for RSA".to_string())
266        })?;
267
268        // Extract key ID from token header
269        let header = jsonwebtoken::decode_header(token)
270            .map_err(|e| AuthError::InvalidToken(format!("Invalid token header: {}", e)))?;
271
272        debug!(kid = ?header.kid, alg = ?header.alg, "Validating RSA token");
273
274        // Get key from JWKS
275        let key = if let Some(kid) = header.kid {
276            jwks.get_key(&kid).await.map_err(|e| {
277                AuthError::InvalidToken(format!("Failed to get key '{}': {}", kid, e))
278            })?
279        } else {
280            jwks.get_any_key()
281                .await
282                .map_err(|e| AuthError::InvalidToken(format!("Failed to get JWKS key: {}", e)))?
283        };
284
285        self.decode_and_validate(token, &key)
286    }
287
288    /// Decode and validate token with the given key.
289    fn decode_and_validate(&self, token: &str, key: &DecodingKey) -> Result<Claims, AuthError> {
290        let mut validation = Validation::new(self.config.algorithm.into());
291
292        // Configure validation
293        validation.validate_exp = true;
294        validation.validate_nbf = false;
295        validation.leeway = 60; // 60 seconds clock skew tolerance
296
297        // Require exp and sub claims
298        validation.set_required_spec_claims(&["exp", "sub"]);
299
300        // Validate issuer if configured
301        if let Some(ref issuer) = self.config.issuer {
302            validation.set_issuer(&[issuer]);
303        }
304
305        // Validate audience if configured
306        if let Some(ref audience) = self.config.audience {
307            validation.set_audience(&[audience]);
308        } else {
309            validation.validate_aud = false;
310        }
311
312        let token_data =
313            decode::<Claims>(token, key, &validation).map_err(|e| self.map_jwt_error(e))?;
314
315        Ok(token_data.claims)
316    }
317
318    /// Map jsonwebtoken errors to AuthError.
319    fn map_jwt_error(&self, e: jsonwebtoken::errors::Error) -> AuthError {
320        match e.kind() {
321            jsonwebtoken::errors::ErrorKind::ExpiredSignature => AuthError::TokenExpired,
322            jsonwebtoken::errors::ErrorKind::InvalidSignature => {
323                AuthError::InvalidToken("Invalid signature".to_string())
324            }
325            jsonwebtoken::errors::ErrorKind::InvalidToken => {
326                AuthError::InvalidToken("Invalid token format".to_string())
327            }
328            jsonwebtoken::errors::ErrorKind::MissingRequiredClaim(claim) => {
329                AuthError::InvalidToken(format!("Missing required claim: {}", claim))
330            }
331            jsonwebtoken::errors::ErrorKind::InvalidIssuer => {
332                AuthError::InvalidToken("Invalid issuer".to_string())
333            }
334            jsonwebtoken::errors::ErrorKind::InvalidAudience => {
335                AuthError::InvalidToken("Invalid audience".to_string())
336            }
337            _ => AuthError::InvalidToken(e.to_string()),
338        }
339    }
340
341    /// Decode JWT token without signature verification (DEV MODE ONLY).
342    fn decode_without_verification(&self, token: &str) -> Result<Claims, AuthError> {
343        let token_data =
344            dangerous::insecure_decode::<Claims>(token).map_err(|e| match e.kind() {
345                jsonwebtoken::errors::ErrorKind::InvalidToken => {
346                    AuthError::InvalidToken("Invalid token format".to_string())
347                }
348                _ => AuthError::InvalidToken(e.to_string()),
349            })?;
350
351        // Still check expiration in dev mode
352        if token_data.claims.is_expired() {
353            return Err(AuthError::TokenExpired);
354        }
355
356        Ok(token_data.claims)
357    }
358}
359
360/// Authentication errors.
361#[derive(Debug, Clone, thiserror::Error)]
362pub enum AuthError {
363    #[error("Missing authorization header")]
364    MissingHeader,
365    #[error("Invalid authorization header format")]
366    InvalidHeader,
367    #[error("Invalid token: {0}")]
368    InvalidToken(String),
369    #[error("Token expired")]
370    TokenExpired,
371}
372
373/// Extract token from request headers.
374pub fn extract_token(req: &Request<Body>) -> Result<Option<String>, AuthError> {
375    let Some(header_value) = req.headers().get(axum::http::header::AUTHORIZATION) else {
376        return Ok(None);
377    };
378
379    let header = header_value
380        .to_str()
381        .map_err(|_| AuthError::InvalidHeader)?;
382    let token = header
383        .strip_prefix("Bearer ")
384        .ok_or(AuthError::InvalidHeader)?
385        .trim();
386
387    if token.is_empty() {
388        return Err(AuthError::InvalidHeader);
389    }
390
391    Ok(Some(token.to_string()))
392}
393
394/// Extract auth context from token (async, supports both HMAC and RSA/JWKS).
395pub async fn extract_auth_context_async(
396    token: Option<String>,
397    middleware: &AuthMiddleware,
398) -> Result<AuthContext, AuthError> {
399    match token {
400        Some(token) => middleware
401            .validate_token_async(&token)
402            .await
403            .map(build_auth_context_from_claims),
404        None => Ok(AuthContext::unauthenticated()),
405    }
406}
407
408/// Build auth context from validated claims.
409///
410/// This handles both UUID and non-UUID subjects properly:
411/// - UUID subjects: uses `authenticated()` with the parsed UUID
412/// - Non-UUID subjects: uses `authenticated_without_uuid()` and stores raw subject in claims
413pub fn build_auth_context_from_claims(claims: Claims) -> AuthContext {
414    // Try to parse subject as UUID first (before moving claims)
415    let user_id = claims.user_id();
416
417    // Build custom claims with raw subject included
418    let mut custom_claims = claims.custom;
419    custom_claims.insert("sub".to_string(), serde_json::Value::String(claims.sub));
420
421    match user_id {
422        Some(uuid) => {
423            // Subject is a valid UUID
424            AuthContext::authenticated(uuid, claims.roles, custom_claims)
425        }
426        None => {
427            // Subject is not a UUID (e.g., Firebase uid, Clerk user_xxx, email)
428            // Still authenticated, but user_id() will return None
429            AuthContext::authenticated_without_uuid(claims.roles, custom_claims)
430        }
431    }
432}
433
434/// Authentication middleware function.
435pub async fn auth_middleware(
436    State(middleware): State<Arc<AuthMiddleware>>,
437    req: Request<Body>,
438    next: Next,
439) -> Response {
440    let token = match extract_token(&req) {
441        Ok(token) => token,
442        Err(e) => {
443            tracing::warn!(error = %e, "Invalid authorization header");
444            return (
445                StatusCode::UNAUTHORIZED,
446                Json(serde_json::json!({
447                    "success": false,
448                    "error": { "code": "UNAUTHORIZED", "message": "Invalid authorization header" }
449                })),
450            )
451                .into_response();
452        }
453    };
454    tracing::trace!(
455        token_present = token.is_some(),
456        "Auth middleware processing request"
457    );
458
459    let auth_context = match extract_auth_context_async(token, &middleware).await {
460        Ok(auth_context) => auth_context,
461        Err(e) => {
462            tracing::warn!(error = %e, "Token validation failed");
463            return (
464                StatusCode::UNAUTHORIZED,
465                Json(serde_json::json!({
466                    "success": false,
467                    "error": { "code": "UNAUTHORIZED", "message": "Invalid authentication token" }
468                })),
469            )
470                .into_response();
471        }
472    };
473    tracing::trace!(
474        authenticated = auth_context.is_authenticated(),
475        "Auth context created"
476    );
477
478    let mut req = req;
479    req.extensions_mut().insert(auth_context);
480
481    next.run(req).await
482}
483
#[cfg(test)]
// Test fixtures are known-good, so unwrapping and panicking on surprises is intended.
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header, encode};

    /// Claims for `test-user-id` with role `user`.
    ///
    /// They are valid for an hour, or already expired by an hour when `expired` is set.
    fn create_test_claims(expired: bool) -> Claims {
        use forge_core::auth::ClaimsBuilder;

        let lifetime_secs = if expired { -3600 } else { 3600 };
        ClaimsBuilder::new()
            .subject("test-user-id")
            .role("user")
            .duration_secs(lifetime_secs)
            .build()
            .unwrap()
    }

    /// Sign `claims` with `secret` using the default (HS256) header.
    fn create_test_token(claims: &Claims, secret: &str) -> String {
        let key = EncodingKey::from_secret(secret.as_bytes());
        encode(&Header::default(), claims, &key).unwrap()
    }

    /// Middleware verifying HMAC tokens with `secret`.
    fn hmac_middleware(secret: &str) -> AuthMiddleware {
        AuthMiddleware::new(AuthConfig::with_secret(secret))
    }

    #[test]
    fn test_auth_config_default() {
        let config = AuthConfig::default();
        assert_eq!(config.algorithm, JwtAlgorithm::HS256);
        assert!(!config.skip_verification);
    }

    #[test]
    fn test_auth_config_dev_mode() {
        assert!(AuthConfig::dev_mode().skip_verification);
    }

    #[test]
    fn test_auth_middleware_permissive() {
        assert!(AuthMiddleware::permissive().config.skip_verification);
    }

    #[tokio::test]
    async fn test_valid_token_with_correct_secret() {
        let secret = "test-secret-key";
        let token = create_test_token(&create_test_claims(false), secret);

        let validated = hmac_middleware(secret)
            .validate_token_async(&token)
            .await
            .unwrap();
        assert_eq!(validated.sub, "test-user-id");
    }

    #[tokio::test]
    async fn test_valid_token_with_wrong_secret() {
        let token = create_test_token(&create_test_claims(false), "wrong-secret");
        let result = hmac_middleware("correct-secret")
            .validate_token_async(&token)
            .await;
        assert!(
            matches!(result, Err(AuthError::InvalidToken(_))),
            "Expected InvalidToken error"
        );
    }

    #[tokio::test]
    async fn test_expired_token() {
        let secret = "test-secret";
        let token = create_test_token(&create_test_claims(true), secret);
        let result = hmac_middleware(secret).validate_token_async(&token).await;
        assert!(
            matches!(result, Err(AuthError::TokenExpired)),
            "Expected TokenExpired error"
        );
    }

    #[tokio::test]
    async fn test_tampered_token() {
        let secret = "test-secret";
        let mut token = create_test_token(&create_test_claims(false), secret);

        // Swap the final signature character so the MAC no longer matches.
        if let Some(last) = token.pop() {
            token.push(if last == 'a' { 'b' } else { 'a' });
        }

        let result = hmac_middleware(secret).validate_token_async(&token).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_dev_mode_skips_signature() {
        let middleware = AuthMiddleware::new(AuthConfig::dev_mode());
        // Signed with an arbitrary secret; dev mode must not care.
        let token = create_test_token(&create_test_claims(false), "any-secret");
        assert!(middleware.validate_token_async(&token).await.is_ok());
    }

    #[tokio::test]
    async fn test_dev_mode_still_checks_expiration() {
        let middleware = AuthMiddleware::new(AuthConfig::dev_mode());
        let token = create_test_token(&create_test_claims(true), "any-secret");
        let result = middleware.validate_token_async(&token).await;
        assert!(
            matches!(result, Err(AuthError::TokenExpired)),
            "Expected TokenExpired error even in dev mode"
        );
    }

    #[tokio::test]
    async fn test_invalid_token_format() {
        let result = hmac_middleware("secret")
            .validate_token_async("not-a-valid-jwt")
            .await;
        assert!(
            matches!(result, Err(AuthError::InvalidToken(_))),
            "Expected InvalidToken error"
        );
    }

    #[test]
    fn test_algorithm_conversion() {
        let cases = [
            // HMAC algorithms
            (JwtAlgorithm::HS256, Algorithm::HS256),
            (JwtAlgorithm::HS384, Algorithm::HS384),
            (JwtAlgorithm::HS512, Algorithm::HS512),
            // RSA algorithms
            (JwtAlgorithm::RS256, Algorithm::RS256),
            (JwtAlgorithm::RS384, Algorithm::RS384),
            (JwtAlgorithm::RS512, Algorithm::RS512),
        ];
        for (ours, expected) in cases {
            assert_eq!(Algorithm::from(ours), expected);
        }
    }

    #[test]
    fn test_is_hmac_and_is_rsa() {
        let hmac_config = AuthConfig::with_secret("test");
        assert!(hmac_config.is_hmac());
        assert!(!hmac_config.is_rsa());

        let rsa_config = AuthConfig {
            algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        assert!(!rsa_config.is_hmac());
        assert!(rsa_config.is_rsa());
    }

    #[test]
    fn test_extract_token_rejects_non_bearer_header() {
        let req = Request::builder()
            .header(axum::http::header::AUTHORIZATION, "Basic abc")
            .body(Body::empty())
            .unwrap();

        assert!(matches!(extract_token(&req), Err(AuthError::InvalidHeader)));
    }

    #[tokio::test]
    async fn test_extract_auth_context_async_invalid_token_errors() {
        let middleware = hmac_middleware("secret");
        let result = extract_auth_context_async(Some("bad.token".to_string()), &middleware).await;
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }
}