forge_runtime/gateway/auth.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use axum::{
5    body::Body,
6    extract::{Request, State},
7    middleware::Next,
8    response::Response,
9};
10use forge_core::auth::Claims;
11use forge_core::function::AuthContext;
12use jsonwebtoken::{dangerous, decode, Algorithm, DecodingKey, Validation};
13use uuid::Uuid;
14
15/// Authentication configuration.
16#[derive(Debug, Clone)]
17pub struct AuthConfig {
18    /// JWT secret for HMAC signing.
19    pub jwt_secret: String,
20    /// JWT algorithm (HS256, HS384, HS512).
21    pub algorithm: JwtAlgorithm,
22    /// Whether to allow unauthenticated requests.
23    pub allow_anonymous: bool,
24    /// Skip signature verification (DEV MODE ONLY - NEVER USE IN PRODUCTION).
25    /// This allows testing with any JWT token without a valid signature.
26    pub skip_verification: bool,
27}
28
29impl Default for AuthConfig {
30    fn default() -> Self {
31        Self {
32            jwt_secret: String::new(),
33            algorithm: JwtAlgorithm::HS256,
34            allow_anonymous: true,
35            skip_verification: false,
36        }
37    }
38}
39
40impl AuthConfig {
41    /// Create a new auth config with the given secret.
42    pub fn with_secret(secret: impl Into<String>) -> Self {
43        Self {
44            jwt_secret: secret.into(),
45            ..Default::default()
46        }
47    }
48
49    /// Create a dev mode config that skips signature verification.
50    /// WARNING: Only use this for development and testing!
51    pub fn dev_mode() -> Self {
52        Self {
53            jwt_secret: String::new(),
54            algorithm: JwtAlgorithm::HS256,
55            allow_anonymous: true,
56            skip_verification: true,
57        }
58    }
59}
60
/// HMAC-based JWT signing algorithms supported by the gateway.
///
/// [`JwtAlgorithm::HS256`] is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum JwtAlgorithm {
    /// HMAC using SHA-256.
    #[default]
    HS256,
    /// HMAC using SHA-384.
    HS384,
    /// HMAC using SHA-512.
    HS512,
}
69
70impl From<JwtAlgorithm> for Algorithm {
71    fn from(alg: JwtAlgorithm) -> Self {
72        match alg {
73            JwtAlgorithm::HS256 => Algorithm::HS256,
74            JwtAlgorithm::HS384 => Algorithm::HS384,
75            JwtAlgorithm::HS512 => Algorithm::HS512,
76        }
77    }
78}
79
80/// Authentication middleware.
81#[derive(Clone)]
82pub struct AuthMiddleware {
83    config: Arc<AuthConfig>,
84    decoding_key: Option<DecodingKey>,
85}
86
87impl std::fmt::Debug for AuthMiddleware {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        f.debug_struct("AuthMiddleware")
90            .field("config", &self.config)
91            .field("decoding_key", &self.decoding_key.is_some())
92            .finish()
93    }
94}
95
96impl AuthMiddleware {
97    /// Create a new auth middleware.
98    pub fn new(config: AuthConfig) -> Self {
99        let decoding_key = if config.skip_verification || config.jwt_secret.is_empty() {
100            None
101        } else {
102            Some(DecodingKey::from_secret(config.jwt_secret.as_bytes()))
103        };
104
105        Self {
106            config: Arc::new(config),
107            decoding_key,
108        }
109    }
110
111    /// Create a middleware that allows all requests (development mode).
112    /// WARNING: This skips signature verification! Never use in production.
113    pub fn permissive() -> Self {
114        Self::new(AuthConfig::dev_mode())
115    }
116
117    /// Get the config.
118    pub fn config(&self) -> &AuthConfig {
119        &self.config
120    }
121
122    /// Validate a JWT token and extract claims.
123    pub fn validate_token(&self, token: &str) -> Result<Claims, AuthError> {
124        if self.config.skip_verification {
125            // DEV MODE: Skip signature verification
126            self.decode_without_verification(token)
127        } else if let Some(ref key) = self.decoding_key {
128            self.decode_with_verification(token, key)
129        } else {
130            Err(AuthError::InvalidToken(
131                "JWT secret not configured".to_string(),
132            ))
133        }
134    }
135
136    /// Decode and verify JWT token using jsonwebtoken crate.
137    fn decode_with_verification(
138        &self,
139        token: &str,
140        key: &DecodingKey,
141    ) -> Result<Claims, AuthError> {
142        let mut validation = Validation::new(self.config.algorithm.into());
143
144        // Configure validation
145        validation.validate_exp = true;
146        validation.validate_nbf = false;
147        validation.validate_aud = false;
148        validation.leeway = 60; // 60 seconds clock skew tolerance
149
150        // Require exp claim
151        validation.set_required_spec_claims(&["exp", "sub"]);
152
153        let token_data = decode::<Claims>(token, key, &validation).map_err(|e| match e.kind() {
154            jsonwebtoken::errors::ErrorKind::ExpiredSignature => AuthError::TokenExpired,
155            jsonwebtoken::errors::ErrorKind::InvalidSignature => {
156                AuthError::InvalidToken("Invalid signature".to_string())
157            }
158            jsonwebtoken::errors::ErrorKind::InvalidToken => {
159                AuthError::InvalidToken("Invalid token format".to_string())
160            }
161            jsonwebtoken::errors::ErrorKind::MissingRequiredClaim(claim) => {
162                AuthError::InvalidToken(format!("Missing required claim: {}", claim))
163            }
164            _ => AuthError::InvalidToken(e.to_string()),
165        })?;
166
167        Ok(token_data.claims)
168    }
169
170    /// Decode JWT token without signature verification (DEV MODE ONLY).
171    /// This parses the token structure but does not validate the signature.
172    fn decode_without_verification(&self, token: &str) -> Result<Claims, AuthError> {
173        let token_data =
174            dangerous::insecure_decode::<Claims>(token).map_err(|e| match e.kind() {
175                jsonwebtoken::errors::ErrorKind::InvalidToken => {
176                    AuthError::InvalidToken("Invalid token format".to_string())
177                }
178                _ => AuthError::InvalidToken(e.to_string()),
179            })?;
180
181        // Still check expiration in dev mode
182        if token_data.claims.is_expired() {
183            return Err(AuthError::TokenExpired);
184        }
185
186        Ok(token_data.claims)
187    }
188}
189
190/// Authentication errors.
191#[derive(Debug, Clone, thiserror::Error)]
192pub enum AuthError {
193    #[error("Missing authorization header")]
194    MissingHeader,
195    #[error("Invalid authorization header format")]
196    InvalidHeader,
197    #[error("Invalid token: {0}")]
198    InvalidToken(String),
199    #[error("Token expired")]
200    TokenExpired,
201}
202
203/// Extract auth context from request.
204pub fn extract_auth_context(req: &Request<Body>, middleware: &AuthMiddleware) -> AuthContext {
205    // Try to extract Authorization header
206    let auth_header = req
207        .headers()
208        .get(axum::http::header::AUTHORIZATION)
209        .and_then(|v| v.to_str().ok());
210
211    let token = match auth_header {
212        Some(header) if header.starts_with("Bearer ") => {
213            Some(header.trim_start_matches("Bearer ").trim())
214        }
215        _ => None,
216    };
217
218    match token {
219        Some(token) => match middleware.validate_token(token) {
220            Ok(claims) => {
221                let user_id = claims.user_id().unwrap_or_else(Uuid::nil);
222                let custom_claims: HashMap<String, serde_json::Value> = claims.custom;
223                AuthContext::authenticated(user_id, claims.roles, custom_claims)
224            }
225            Err(_) => AuthContext::unauthenticated(),
226        },
227        None => AuthContext::unauthenticated(),
228    }
229}
230
231/// Authentication middleware function.
232pub async fn auth_middleware(
233    State(middleware): State<Arc<AuthMiddleware>>,
234    req: Request<Body>,
235    next: Next,
236) -> Response {
237    let auth_context = extract_auth_context(&req, &middleware);
238
239    // Store auth context in request extensions
240    let mut req = req;
241    req.extensions_mut().insert(auth_context);
242
243    next.run(req).await
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{encode, EncodingKey, Header};

    /// Claims for `test-user-id` with role `user`. They are valid for one hour,
    /// or expired one hour ago when `expired` is true.
    fn create_test_claims(expired: bool) -> Claims {
        use forge_core::auth::ClaimsBuilder;

        let mut builder = ClaimsBuilder::new().subject("test-user-id").role("user");

        if expired {
            builder = builder.duration_secs(-3600); // Expired 1 hour ago
        } else {
            builder = builder.duration_secs(3600); // Valid for 1 hour
        }

        builder.build().unwrap()
    }

    /// Sign `claims` with HS256 (default header) using `secret`.
    fn create_test_token(claims: &Claims, secret: &str) -> String {
        encode(
            &Header::default(),
            claims,
            &EncodingKey::from_secret(secret.as_bytes()),
        )
        .unwrap()
    }

    #[test]
    fn test_auth_config_default() {
        let config = AuthConfig::default();
        assert!(config.allow_anonymous);
        assert_eq!(config.algorithm, JwtAlgorithm::HS256);
        assert!(!config.skip_verification);
    }

    #[test]
    fn test_auth_config_dev_mode() {
        let config = AuthConfig::dev_mode();
        assert!(config.skip_verification);
        assert!(config.allow_anonymous);
    }

    #[test]
    fn test_auth_middleware_permissive() {
        let middleware = AuthMiddleware::permissive();
        assert!(middleware.config.skip_verification);
    }

    #[test]
    fn test_valid_token_with_correct_secret() {
        let secret = "test-secret-key";
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(secret));

        let claims = create_test_claims(false);
        let token = create_test_token(&claims, secret);

        let result = middleware.validate_token(&token);
        assert!(result.is_ok());
        let validated_claims = result.unwrap();
        assert_eq!(validated_claims.sub, "test-user-id");
    }

    #[test]
    fn test_valid_token_with_wrong_secret() {
        let middleware = AuthMiddleware::new(AuthConfig::with_secret("correct-secret"));

        let claims = create_test_claims(false);
        let token = create_test_token(&claims, "wrong-secret");

        match middleware.validate_token(&token) {
            Err(AuthError::InvalidToken(msg)) => assert_eq!(msg, "Invalid signature"),
            Err(other) => panic!("Expected InvalidToken(\"Invalid signature\"), got {other}"),
            Ok(_) => panic!("Token signed with the wrong secret was accepted"),
        }
    }

    #[test]
    fn test_expired_token() {
        let secret = "test-secret";
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(secret));

        let claims = create_test_claims(true); // Expired
        let token = create_test_token(&claims, secret);

        let result = middleware.validate_token(&token);
        assert!(result.is_err());
        match result {
            Err(AuthError::TokenExpired) => {}
            _ => panic!("Expected TokenExpired error"),
        }
    }

    #[test]
    fn test_tampered_token() {
        use forge_core::auth::ClaimsBuilder;

        let secret = "test-secret";
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(secret));

        // A genuine token for an ordinary user...
        let genuine = create_test_token(&create_test_claims(false), secret);
        // ...and a token carrying escalated claims, signed with another key.
        let forged_claims = ClaimsBuilder::new()
            .subject("attacker-user-id")
            .role("admin")
            .duration_secs(3600)
            .build()
            .unwrap();
        let forged = create_test_token(&forged_claims, "attacker-secret");

        // Splice the forged header+payload onto the genuine signature. Tampering
        // with the payload instead of flipping the last signature character
        // guarantees the failure comes from signature verification. A flipped
        // last character can also fail during base64 decoding, because for
        // HS256 that character carries non-canonical padding bits.
        let (forged_signing_input, _) = forged.rsplit_once('.').unwrap();
        let (_, genuine_signature) = genuine.rsplit_once('.').unwrap();
        let tampered = format!("{forged_signing_input}.{genuine_signature}");

        match middleware.validate_token(&tampered) {
            Err(AuthError::InvalidToken(msg)) => assert_eq!(msg, "Invalid signature"),
            Err(other) => panic!("Expected InvalidToken(\"Invalid signature\"), got {other}"),
            Ok(_) => panic!("Tampered token was accepted"),
        }
    }

    #[test]
    fn test_dev_mode_skips_signature() {
        let middleware = AuthMiddleware::new(AuthConfig::dev_mode());

        // Create token with any secret
        let claims = create_test_claims(false);
        let token = create_test_token(&claims, "any-secret");

        // Should still validate in dev mode
        let result = middleware.validate_token(&token);
        assert!(result.is_ok());
    }

    #[test]
    fn test_dev_mode_still_checks_expiration() {
        let middleware = AuthMiddleware::new(AuthConfig::dev_mode());

        let claims = create_test_claims(true); // Expired
        let token = create_test_token(&claims, "any-secret");

        let result = middleware.validate_token(&token);
        assert!(result.is_err());
        match result {
            Err(AuthError::TokenExpired) => {}
            _ => panic!("Expected TokenExpired error even in dev mode"),
        }
    }

    #[test]
    fn test_invalid_token_format() {
        let middleware = AuthMiddleware::new(AuthConfig::with_secret("secret"));

        let result = middleware.validate_token("not-a-valid-jwt");
        assert!(result.is_err());
        match result {
            Err(AuthError::InvalidToken(_)) => {}
            _ => panic!("Expected InvalidToken error"),
        }
    }

    #[test]
    fn test_algorithm_conversion() {
        assert_eq!(Algorithm::from(JwtAlgorithm::HS256), Algorithm::HS256);
        assert_eq!(Algorithm::from(JwtAlgorithm::HS384), Algorithm::HS384);
        assert_eq!(Algorithm::from(JwtAlgorithm::HS512), Algorithm::HS512);
    }
}