Skip to main content

forge_runtime/gateway/
auth.rs

1use std::sync::Arc;
2
3use axum::{
4    body::Body,
5    extract::{Request, State},
6    http::StatusCode,
7    middleware::Next,
8    response::{IntoResponse, Json, Response},
9};
10use forge_core::auth::Claims;
11use forge_core::config::JwtAlgorithm as CoreJwtAlgorithm;
12use forge_core::function::AuthContext;
13use jsonwebtoken::{Algorithm, DecodingKey, Validation, dangerous, decode};
14use tracing::debug;
15
16use super::jwks::JwksClient;
17
18/// Authentication configuration for the runtime.
19#[derive(Debug, Clone)]
20pub struct AuthConfig {
21    /// JWT secret for HMAC algorithms (HS256, HS384, HS512).
22    pub jwt_secret: Option<String>,
23    /// JWT algorithm.
24    pub algorithm: JwtAlgorithm,
25    /// JWKS client for RSA algorithms.
26    pub jwks_client: Option<Arc<JwksClient>>,
27    /// Expected token issuer (iss claim).
28    pub issuer: Option<String>,
29    /// Expected audience (aud claim).
30    pub audience: Option<String>,
31    /// Skip signature verification (DEV MODE ONLY - NEVER USE IN PRODUCTION).
32    pub skip_verification: bool,
33}
34
35impl Default for AuthConfig {
36    fn default() -> Self {
37        Self {
38            jwt_secret: None,
39            algorithm: JwtAlgorithm::HS256,
40            jwks_client: None,
41            issuer: None,
42            audience: None,
43            skip_verification: false,
44        }
45    }
46}
47
48impl AuthConfig {
49    /// Create auth config from forge core config.
50    pub fn from_forge_config(
51        config: &forge_core::config::AuthConfig,
52    ) -> Result<Self, super::jwks::JwksError> {
53        let algorithm = JwtAlgorithm::from(config.jwt_algorithm);
54
55        let jwks_client = config
56            .jwks_url
57            .as_ref()
58            .map(|url| JwksClient::new(url.clone(), config.jwks_cache_ttl_secs).map(Arc::new))
59            .transpose()?;
60
61        Ok(Self {
62            jwt_secret: config.jwt_secret.clone(),
63            algorithm,
64            jwks_client,
65            issuer: config.jwt_issuer.clone(),
66            audience: config.jwt_audience.clone(),
67            skip_verification: false,
68        })
69    }
70
71    /// Create a new auth config with the given HMAC secret.
72    pub fn with_secret(secret: impl Into<String>) -> Self {
73        Self {
74            jwt_secret: Some(secret.into()),
75            ..Default::default()
76        }
77    }
78
79    /// Create a dev mode config that skips signature verification.
80    /// WARNING: Only use this for development and testing!
81    pub fn dev_mode() -> Self {
82        Self {
83            jwt_secret: None,
84            algorithm: JwtAlgorithm::HS256,
85            jwks_client: None,
86            issuer: None,
87            audience: None,
88            skip_verification: true,
89        }
90    }
91
92    /// Check if this config uses HMAC (symmetric) algorithms.
93    pub fn is_hmac(&self) -> bool {
94        matches!(
95            self.algorithm,
96            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
97        )
98    }
99
100    /// Check if this config uses RSA (asymmetric) algorithms.
101    pub fn is_rsa(&self) -> bool {
102        matches!(
103            self.algorithm,
104            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
105        )
106    }
107}
108
/// JWT signing algorithms the runtime can verify.
///
/// `HS*` variants are HMAC (shared secret); `RS*` variants are RSA and are verified
/// with keys obtained through the JWKS client. The default is [`JwtAlgorithm::HS256`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum JwtAlgorithm {
    /// HMAC using SHA-256 (default).
    #[default]
    HS256,
    /// HMAC using SHA-384.
    HS384,
    /// HMAC using SHA-512.
    HS512,
    /// RSASSA-PKCS1-v1_5 using SHA-256.
    RS256,
    /// RSASSA-PKCS1-v1_5 using SHA-384.
    RS384,
    /// RSASSA-PKCS1-v1_5 using SHA-512.
    RS512,
}
120
121impl From<JwtAlgorithm> for Algorithm {
122    fn from(alg: JwtAlgorithm) -> Self {
123        match alg {
124            JwtAlgorithm::HS256 => Algorithm::HS256,
125            JwtAlgorithm::HS384 => Algorithm::HS384,
126            JwtAlgorithm::HS512 => Algorithm::HS512,
127            JwtAlgorithm::RS256 => Algorithm::RS256,
128            JwtAlgorithm::RS384 => Algorithm::RS384,
129            JwtAlgorithm::RS512 => Algorithm::RS512,
130        }
131    }
132}
133
134impl From<CoreJwtAlgorithm> for JwtAlgorithm {
135    fn from(alg: CoreJwtAlgorithm) -> Self {
136        match alg {
137            CoreJwtAlgorithm::HS256 => JwtAlgorithm::HS256,
138            CoreJwtAlgorithm::HS384 => JwtAlgorithm::HS384,
139            CoreJwtAlgorithm::HS512 => JwtAlgorithm::HS512,
140            CoreJwtAlgorithm::RS256 => JwtAlgorithm::RS256,
141            CoreJwtAlgorithm::RS384 => JwtAlgorithm::RS384,
142            CoreJwtAlgorithm::RS512 => JwtAlgorithm::RS512,
143        }
144    }
145}
146
/// Authentication middleware.
///
/// Holds the shared [`AuthConfig`] and, for HMAC algorithms, a decoding key that
/// `AuthMiddleware::new` derives once so it is not rebuilt on every request.
#[derive(Clone)]
pub struct AuthMiddleware {
    /// Shared authentication settings; only exposed by shared reference.
    config: Arc<AuthConfig>,
    /// Pre-computed HMAC decoding key (for performance).
    /// `None` in dev mode, for non-HMAC algorithms, or when no non-empty secret is set.
    hmac_key: Option<DecodingKey>,
}
154
155impl std::fmt::Debug for AuthMiddleware {
156    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157        f.debug_struct("AuthMiddleware")
158            .field("config", &self.config)
159            .field("hmac_key", &self.hmac_key.is_some())
160            .finish()
161    }
162}
163
164impl AuthMiddleware {
165    /// Create a new auth middleware.
166    pub fn new(config: AuthConfig) -> Self {
167        if config.skip_verification {
168            tracing::warn!("JWT signature verification is DISABLED. Do not use in production.");
169        }
170
171        // Pre-compute HMAC key if using HMAC algorithm
172        let hmac_key = if config.skip_verification {
173            None
174        } else if config.is_hmac() {
175            config
176                .jwt_secret
177                .as_ref()
178                .filter(|s| !s.is_empty())
179                .map(|secret| DecodingKey::from_secret(secret.as_bytes()))
180        } else {
181            None
182        };
183
184        Self {
185            config: Arc::new(config),
186            hmac_key,
187        }
188    }
189
190    /// Create a middleware that allows all requests (development mode).
191    /// WARNING: This skips signature verification! Never use in production.
192    pub fn permissive() -> Self {
193        Self::new(AuthConfig::dev_mode())
194    }
195
196    /// Get the config.
197    pub fn config(&self) -> &AuthConfig {
198        &self.config
199    }
200
201    /// Validate a JWT token and extract claims.
202    pub async fn validate_token_async(&self, token: &str) -> Result<Claims, AuthError> {
203        if self.config.skip_verification {
204            return self.decode_without_verification(token);
205        }
206
207        if self.config.is_hmac() {
208            self.validate_hmac(token)
209        } else {
210            self.validate_rsa(token).await
211        }
212    }
213
214    /// Validate HMAC-signed token.
215    fn validate_hmac(&self, token: &str) -> Result<Claims, AuthError> {
216        let key = self.hmac_key.as_ref().ok_or_else(|| {
217            AuthError::InvalidToken("JWT secret not configured for HMAC".to_string())
218        })?;
219
220        self.decode_and_validate(token, key)
221    }
222
223    /// Validate RSA-signed token using JWKS.
224    async fn validate_rsa(&self, token: &str) -> Result<Claims, AuthError> {
225        let jwks = self.config.jwks_client.as_ref().ok_or_else(|| {
226            AuthError::InvalidToken("JWKS URL not configured for RSA".to_string())
227        })?;
228
229        // Extract key ID from token header
230        let header = jsonwebtoken::decode_header(token)
231            .map_err(|e| AuthError::InvalidToken(format!("Invalid token header: {}", e)))?;
232
233        debug!(kid = ?header.kid, alg = ?header.alg, "Validating RSA token");
234
235        // Get key from JWKS
236        let key = if let Some(kid) = header.kid {
237            jwks.get_key(&kid).await.map_err(|e| {
238                AuthError::InvalidToken(format!("Failed to get key '{}': {}", kid, e))
239            })?
240        } else {
241            jwks.get_any_key()
242                .await
243                .map_err(|e| AuthError::InvalidToken(format!("Failed to get JWKS key: {}", e)))?
244        };
245
246        self.decode_and_validate(token, &key)
247    }
248
249    /// Decode and validate token with the given key.
250    fn decode_and_validate(&self, token: &str, key: &DecodingKey) -> Result<Claims, AuthError> {
251        let mut validation = Validation::new(self.config.algorithm.into());
252
253        // Configure validation
254        validation.validate_exp = true;
255        validation.validate_nbf = false;
256        validation.leeway = 60; // 60 seconds clock skew tolerance
257
258        // Require exp and sub claims
259        validation.set_required_spec_claims(&["exp", "sub"]);
260
261        // Validate issuer if configured
262        if let Some(ref issuer) = self.config.issuer {
263            validation.set_issuer(&[issuer]);
264        }
265
266        // Validate audience if configured
267        if let Some(ref audience) = self.config.audience {
268            validation.set_audience(&[audience]);
269        } else {
270            validation.validate_aud = false;
271        }
272
273        let token_data =
274            decode::<Claims>(token, key, &validation).map_err(|e| self.map_jwt_error(e))?;
275
276        Ok(token_data.claims)
277    }
278
279    /// Map jsonwebtoken errors to AuthError.
280    fn map_jwt_error(&self, e: jsonwebtoken::errors::Error) -> AuthError {
281        match e.kind() {
282            jsonwebtoken::errors::ErrorKind::ExpiredSignature => AuthError::TokenExpired,
283            jsonwebtoken::errors::ErrorKind::InvalidSignature => {
284                AuthError::InvalidToken("Invalid signature".to_string())
285            }
286            jsonwebtoken::errors::ErrorKind::InvalidToken => {
287                AuthError::InvalidToken("Invalid token format".to_string())
288            }
289            jsonwebtoken::errors::ErrorKind::MissingRequiredClaim(claim) => {
290                AuthError::InvalidToken(format!("Missing required claim: {}", claim))
291            }
292            jsonwebtoken::errors::ErrorKind::InvalidIssuer => {
293                AuthError::InvalidToken("Invalid issuer".to_string())
294            }
295            jsonwebtoken::errors::ErrorKind::InvalidAudience => {
296                AuthError::InvalidToken("Invalid audience".to_string())
297            }
298            _ => AuthError::InvalidToken(e.to_string()),
299        }
300    }
301
302    /// Decode JWT token without signature verification (DEV MODE ONLY).
303    fn decode_without_verification(&self, token: &str) -> Result<Claims, AuthError> {
304        let token_data =
305            dangerous::insecure_decode::<Claims>(token).map_err(|e| match e.kind() {
306                jsonwebtoken::errors::ErrorKind::InvalidToken => {
307                    AuthError::InvalidToken("Invalid token format".to_string())
308                }
309                _ => AuthError::InvalidToken(e.to_string()),
310            })?;
311
312        // Still check expiration in dev mode
313        if token_data.claims.is_expired() {
314            return Err(AuthError::TokenExpired);
315        }
316
317        Ok(token_data.claims)
318    }
319}
320
/// Authentication errors.
///
/// Each variant's `Display` text comes from its `#[error]` attribute. `auth_middleware`
/// logs that text but always answers with a generic 401 body.
#[derive(Debug, Clone, thiserror::Error)]
pub enum AuthError {
    /// No `Authorization` header was present.
    // NOTE(review): not constructed in this file — `extract_token` returns `Ok(None)` for a
    // missing header; confirm whether other modules still rely on this variant.
    #[error("Missing authorization header")]
    MissingHeader,
    /// The `Authorization` header was not valid text, not a `Bearer` credential, or empty.
    #[error("Invalid authorization header format")]
    InvalidHeader,
    /// The token failed decoding or validation; the payload is a human-readable reason.
    #[error("Invalid token: {0}")]
    InvalidToken(String),
    /// The token's expiry has passed.
    #[error("Token expired")]
    TokenExpired,
}
333
334/// Extract token from request headers.
335pub fn extract_token(req: &Request<Body>) -> Result<Option<String>, AuthError> {
336    let Some(header_value) = req.headers().get(axum::http::header::AUTHORIZATION) else {
337        return Ok(None);
338    };
339
340    let header = header_value
341        .to_str()
342        .map_err(|_| AuthError::InvalidHeader)?;
343    let token = header
344        .strip_prefix("Bearer ")
345        .ok_or(AuthError::InvalidHeader)?
346        .trim();
347
348    if token.is_empty() {
349        return Err(AuthError::InvalidHeader);
350    }
351
352    Ok(Some(token.to_string()))
353}
354
355/// Extract auth context from token (async, supports both HMAC and RSA/JWKS).
356pub async fn extract_auth_context_async(
357    token: Option<String>,
358    middleware: &AuthMiddleware,
359) -> Result<AuthContext, AuthError> {
360    match token {
361        Some(token) => middleware
362            .validate_token_async(&token)
363            .await
364            .map(build_auth_context_from_claims),
365        None => Ok(AuthContext::unauthenticated()),
366    }
367}
368
369/// Build auth context from validated claims.
370///
371/// This handles both UUID and non-UUID subjects properly:
372/// - UUID subjects: uses `authenticated()` with the parsed UUID
373/// - Non-UUID subjects: uses `authenticated_without_uuid()` and stores raw subject in claims
374pub fn build_auth_context_from_claims(claims: Claims) -> AuthContext {
375    // Try to parse subject as UUID first (before moving claims)
376    let user_id = claims.user_id();
377
378    // Build custom claims with raw subject included
379    let mut custom_claims = claims.custom;
380    custom_claims.insert("sub".to_string(), serde_json::Value::String(claims.sub));
381
382    match user_id {
383        Some(uuid) => {
384            // Subject is a valid UUID
385            AuthContext::authenticated(uuid, claims.roles, custom_claims)
386        }
387        None => {
388            // Subject is not a UUID (e.g., Firebase uid, Clerk user_xxx, email)
389            // Still authenticated, but user_id() will return None
390            AuthContext::authenticated_without_uuid(claims.roles, custom_claims)
391        }
392    }
393}
394
395/// Authentication middleware function.
396pub async fn auth_middleware(
397    State(middleware): State<Arc<AuthMiddleware>>,
398    req: Request<Body>,
399    next: Next,
400) -> Response {
401    let token = match extract_token(&req) {
402        Ok(token) => token,
403        Err(e) => {
404            tracing::warn!(error = %e, "Invalid authorization header");
405            return (
406                StatusCode::UNAUTHORIZED,
407                Json(serde_json::json!({
408                    "success": false,
409                    "error": { "code": "UNAUTHORIZED", "message": "Invalid authorization header" }
410                })),
411            )
412                .into_response();
413        }
414    };
415    tracing::trace!(
416        token_present = token.is_some(),
417        "Auth middleware processing request"
418    );
419
420    let auth_context = match extract_auth_context_async(token, &middleware).await {
421        Ok(auth_context) => auth_context,
422        Err(e) => {
423            tracing::warn!(error = %e, "Token validation failed");
424            return (
425                StatusCode::UNAUTHORIZED,
426                Json(serde_json::json!({
427                    "success": false,
428                    "error": { "code": "UNAUTHORIZED", "message": "Invalid authentication token" }
429                })),
430            )
431                .into_response();
432        }
433    };
434    tracing::trace!(
435        authenticated = auth_context.is_authenticated(),
436        "Auth context created"
437    );
438
439    let mut req = req;
440    req.extensions_mut().insert(auth_context);
441
442    next.run(req).await
443}
444
// Unit tests for token extraction and HMAC / dev-mode validation. RSA/JWKS paths are
// not exercised here.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header, encode};

    /// Build claims for subject `test-user-id` with role `user`.
    ///
    /// `expired = true` uses a negative lifetime so the token is already expired;
    /// otherwise it is valid for one hour.
    fn create_test_claims(expired: bool) -> Claims {
        use forge_core::auth::ClaimsBuilder;

        let mut builder = ClaimsBuilder::new().subject("test-user-id").role("user");

        if expired {
            builder = builder.duration_secs(-3600); // Expired 1 hour ago
        } else {
            builder = builder.duration_secs(3600); // Valid for 1 hour
        }

        builder.build().unwrap()
    }

    /// Sign `claims` with `secret` using `jsonwebtoken`'s default header (HS256),
    /// which matches the default `AuthConfig` algorithm.
    fn create_test_token(claims: &Claims, secret: &str) -> String {
        encode(
            &Header::default(),
            claims,
            &EncodingKey::from_secret(secret.as_bytes()),
        )
        .unwrap()
    }

    #[test]
    fn test_auth_config_default() {
        let config = AuthConfig::default();
        assert_eq!(config.algorithm, JwtAlgorithm::HS256);
        assert!(!config.skip_verification);
    }

    #[test]
    fn test_auth_config_dev_mode() {
        let config = AuthConfig::dev_mode();
        assert!(config.skip_verification);
    }

    #[test]
    fn test_auth_middleware_permissive() {
        let middleware = AuthMiddleware::permissive();
        assert!(middleware.config.skip_verification);
    }

    // Happy path: a token signed with the configured secret yields its subject back.
    #[tokio::test]
    async fn test_valid_token_with_correct_secret() {
        let secret = "test-secret-key";
        let config = AuthConfig::with_secret(secret);
        let middleware = AuthMiddleware::new(config);

        let claims = create_test_claims(false);
        let token = create_test_token(&claims, secret);

        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_ok());
        let validated_claims = result.unwrap();
        assert_eq!(validated_claims.sub, "test-user-id");
    }

    // A signature made with a different secret must be rejected as InvalidToken.
    #[tokio::test]
    async fn test_valid_token_with_wrong_secret() {
        let config = AuthConfig::with_secret("correct-secret");
        let middleware = AuthMiddleware::new(config);

        let claims = create_test_claims(false);
        let token = create_test_token(&claims, "wrong-secret");

        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_err());
        match result {
            Err(AuthError::InvalidToken(_)) => {}
            _ => panic!("Expected InvalidToken error"),
        }
    }

    // Expiry maps to the dedicated TokenExpired variant, not InvalidToken.
    #[tokio::test]
    async fn test_expired_token() {
        let secret = "test-secret";
        let config = AuthConfig::with_secret(secret);
        let middleware = AuthMiddleware::new(config);

        let claims = create_test_claims(true); // Expired
        let token = create_test_token(&claims, secret);

        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_err());
        match result {
            Err(AuthError::TokenExpired) => {}
            _ => panic!("Expected TokenExpired error"),
        }
    }

    // Flipping the final character alters the signature segment, so verification fails.
    #[tokio::test]
    async fn test_tampered_token() {
        let secret = "test-secret";
        let config = AuthConfig::with_secret(secret);
        let middleware = AuthMiddleware::new(config);

        let claims = create_test_claims(false);
        let mut token = create_test_token(&claims, secret);

        // Tamper with the token by modifying a character in the signature
        if let Some(last_char) = token.pop() {
            let replacement = if last_char == 'a' { 'b' } else { 'a' };
            token.push(replacement);
        }

        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_err());
    }

    // Dev mode accepts a token regardless of which secret signed it.
    #[tokio::test]
    async fn test_dev_mode_skips_signature() {
        let config = AuthConfig::dev_mode();
        let middleware = AuthMiddleware::new(config);

        // Create token with any secret
        let claims = create_test_claims(false);
        let token = create_test_token(&claims, "any-secret");

        // Should still validate in dev mode
        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_ok());
    }

    // Dev mode still rejects expired tokens.
    #[tokio::test]
    async fn test_dev_mode_still_checks_expiration() {
        let config = AuthConfig::dev_mode();
        let middleware = AuthMiddleware::new(config);

        let claims = create_test_claims(true); // Expired
        let token = create_test_token(&claims, "any-secret");

        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_err());
        match result {
            Err(AuthError::TokenExpired) => {}
            _ => panic!("Expected TokenExpired error even in dev mode"),
        }
    }

    // A string without JWT structure is reported as InvalidToken.
    #[tokio::test]
    async fn test_invalid_token_format() {
        let config = AuthConfig::with_secret("secret");
        let middleware = AuthMiddleware::new(config);

        let result = middleware.validate_token_async("not-a-valid-jwt").await;
        assert!(result.is_err());
        match result {
            Err(AuthError::InvalidToken(_)) => {}
            _ => panic!("Expected InvalidToken error"),
        }
    }

    #[test]
    fn test_algorithm_conversion() {
        // HMAC algorithms
        assert_eq!(Algorithm::from(JwtAlgorithm::HS256), Algorithm::HS256);
        assert_eq!(Algorithm::from(JwtAlgorithm::HS384), Algorithm::HS384);
        assert_eq!(Algorithm::from(JwtAlgorithm::HS512), Algorithm::HS512);
        // RSA algorithms
        assert_eq!(Algorithm::from(JwtAlgorithm::RS256), Algorithm::RS256);
        assert_eq!(Algorithm::from(JwtAlgorithm::RS384), Algorithm::RS384);
        assert_eq!(Algorithm::from(JwtAlgorithm::RS512), Algorithm::RS512);
    }

    #[test]
    fn test_is_hmac_and_is_rsa() {
        let hmac_config = AuthConfig::with_secret("test");
        assert!(hmac_config.is_hmac());
        assert!(!hmac_config.is_rsa());

        let rsa_config = AuthConfig {
            algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        assert!(!rsa_config.is_hmac());
        assert!(rsa_config.is_rsa());
    }

    // A present header with a non-Bearer scheme is an error, not an anonymous request.
    #[test]
    fn test_extract_token_rejects_non_bearer_header() {
        let req = Request::builder()
            .header(axum::http::header::AUTHORIZATION, "Basic abc")
            .body(Body::empty())
            .unwrap();

        let result = extract_token(&req);
        assert!(matches!(result, Err(AuthError::InvalidHeader)));
    }

    // A supplied but invalid token must error rather than fall back to unauthenticated.
    #[tokio::test]
    async fn test_extract_auth_context_async_invalid_token_errors() {
        let middleware = AuthMiddleware::new(AuthConfig::with_secret("secret"));
        let result = extract_auth_context_async(Some("bad.token".to_string()), &middleware).await;
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }
}