1use std::sync::Arc;
2
3use axum::{
4 body::Body,
5 extract::{Request, State},
6 middleware::Next,
7 response::Response,
8};
9use forge_core::auth::Claims;
10use forge_core::config::JwtAlgorithm as CoreJwtAlgorithm;
11use forge_core::function::AuthContext;
12use jsonwebtoken::{Algorithm, DecodingKey, Validation, dangerous, decode};
13use tracing::debug;
14
15use super::jwks::JwksClient;
16
17#[derive(Debug, Clone)]
19pub struct AuthConfig {
20 pub jwt_secret: Option<String>,
22 pub algorithm: JwtAlgorithm,
24 pub jwks_client: Option<Arc<JwksClient>>,
26 pub issuer: Option<String>,
28 pub audience: Option<String>,
30 pub allow_anonymous: bool,
32 pub skip_verification: bool,
34}
35
36impl Default for AuthConfig {
37 fn default() -> Self {
38 Self {
39 jwt_secret: None,
40 algorithm: JwtAlgorithm::HS256,
41 jwks_client: None,
42 issuer: None,
43 audience: None,
44 allow_anonymous: true,
45 skip_verification: false,
46 }
47 }
48}
49
50impl AuthConfig {
51 pub fn from_forge_config(config: &forge_core::config::AuthConfig) -> Self {
53 let algorithm = JwtAlgorithm::from(config.jwt_algorithm);
54
55 let jwks_client = config
56 .jwks_url
57 .as_ref()
58 .map(|url| Arc::new(JwksClient::new(url.clone(), config.jwks_cache_ttl_secs)));
59
60 Self {
61 jwt_secret: config.jwt_secret.clone(),
62 algorithm,
63 jwks_client,
64 issuer: config.jwt_issuer.clone(),
65 audience: config.jwt_audience.clone(),
66 allow_anonymous: config.allow_anonymous,
67 skip_verification: false,
68 }
69 }
70
71 pub fn with_secret(secret: impl Into<String>) -> Self {
73 Self {
74 jwt_secret: Some(secret.into()),
75 ..Default::default()
76 }
77 }
78
79 pub fn dev_mode() -> Self {
82 Self {
83 jwt_secret: None,
84 algorithm: JwtAlgorithm::HS256,
85 jwks_client: None,
86 issuer: None,
87 audience: None,
88 allow_anonymous: true,
89 skip_verification: true,
90 }
91 }
92
93 pub fn is_hmac(&self) -> bool {
95 matches!(
96 self.algorithm,
97 JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
98 )
99 }
100
101 pub fn is_rsa(&self) -> bool {
103 matches!(
104 self.algorithm,
105 JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
106 )
107 }
108}
109
/// JWT signing algorithms supported by [`AuthMiddleware`].
///
/// `HS*` variants are HMAC algorithms verified with a shared secret; `RS*` variants are
/// RSA algorithms whose verification keys are obtained from a JWKS endpoint.
#[derive(Default, Debug, Copy, Clone, Eq, PartialEq)]
pub enum JwtAlgorithm {
    /// HMAC using SHA-256 (the default).
    #[default]
    HS256,
    /// HMAC using SHA-384.
    HS384,
    /// HMAC using SHA-512.
    HS512,
    /// RSASSA-PKCS1-v1_5 using SHA-256.
    RS256,
    /// RSASSA-PKCS1-v1_5 using SHA-384.
    RS384,
    /// RSASSA-PKCS1-v1_5 using SHA-512.
    RS512,
}
121
122impl From<JwtAlgorithm> for Algorithm {
123 fn from(alg: JwtAlgorithm) -> Self {
124 match alg {
125 JwtAlgorithm::HS256 => Algorithm::HS256,
126 JwtAlgorithm::HS384 => Algorithm::HS384,
127 JwtAlgorithm::HS512 => Algorithm::HS512,
128 JwtAlgorithm::RS256 => Algorithm::RS256,
129 JwtAlgorithm::RS384 => Algorithm::RS384,
130 JwtAlgorithm::RS512 => Algorithm::RS512,
131 }
132 }
133}
134
135impl From<CoreJwtAlgorithm> for JwtAlgorithm {
136 fn from(alg: CoreJwtAlgorithm) -> Self {
137 match alg {
138 CoreJwtAlgorithm::HS256 => JwtAlgorithm::HS256,
139 CoreJwtAlgorithm::HS384 => JwtAlgorithm::HS384,
140 CoreJwtAlgorithm::HS512 => JwtAlgorithm::HS512,
141 CoreJwtAlgorithm::RS256 => JwtAlgorithm::RS256,
142 CoreJwtAlgorithm::RS384 => JwtAlgorithm::RS384,
143 CoreJwtAlgorithm::RS512 => JwtAlgorithm::RS512,
144 }
145 }
146}
147
148#[derive(Clone)]
150pub struct AuthMiddleware {
151 config: Arc<AuthConfig>,
152 hmac_key: Option<DecodingKey>,
154}
155
156impl std::fmt::Debug for AuthMiddleware {
157 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158 f.debug_struct("AuthMiddleware")
159 .field("config", &self.config)
160 .field("hmac_key", &self.hmac_key.is_some())
161 .finish()
162 }
163}
164
165impl AuthMiddleware {
166 pub fn new(config: AuthConfig) -> Self {
168 let hmac_key = if config.skip_verification {
170 None
171 } else if config.is_hmac() {
172 config
173 .jwt_secret
174 .as_ref()
175 .filter(|s| !s.is_empty())
176 .map(|secret| DecodingKey::from_secret(secret.as_bytes()))
177 } else {
178 None
179 };
180
181 Self {
182 config: Arc::new(config),
183 hmac_key,
184 }
185 }
186
187 pub fn permissive() -> Self {
190 Self::new(AuthConfig::dev_mode())
191 }
192
193 pub fn config(&self) -> &AuthConfig {
195 &self.config
196 }
197
198 pub async fn validate_token_async(&self, token: &str) -> Result<Claims, AuthError> {
200 if self.config.skip_verification {
201 return self.decode_without_verification(token);
202 }
203
204 if self.config.is_hmac() {
205 self.validate_hmac(token)
206 } else {
207 self.validate_rsa(token).await
208 }
209 }
210
211 fn validate_hmac(&self, token: &str) -> Result<Claims, AuthError> {
213 let key = self.hmac_key.as_ref().ok_or_else(|| {
214 AuthError::InvalidToken("JWT secret not configured for HMAC".to_string())
215 })?;
216
217 self.decode_and_validate(token, key)
218 }
219
220 async fn validate_rsa(&self, token: &str) -> Result<Claims, AuthError> {
222 let jwks = self.config.jwks_client.as_ref().ok_or_else(|| {
223 AuthError::InvalidToken("JWKS URL not configured for RSA".to_string())
224 })?;
225
226 let header = jsonwebtoken::decode_header(token)
228 .map_err(|e| AuthError::InvalidToken(format!("Invalid token header: {}", e)))?;
229
230 debug!(kid = ?header.kid, alg = ?header.alg, "Validating RSA token");
231
232 let key = if let Some(kid) = header.kid {
234 jwks.get_key(&kid).await.map_err(|e| {
235 AuthError::InvalidToken(format!("Failed to get key '{}': {}", kid, e))
236 })?
237 } else {
238 jwks.get_any_key()
239 .await
240 .map_err(|e| AuthError::InvalidToken(format!("Failed to get JWKS key: {}", e)))?
241 };
242
243 self.decode_and_validate(token, &key)
244 }
245
246 fn decode_and_validate(&self, token: &str, key: &DecodingKey) -> Result<Claims, AuthError> {
248 let mut validation = Validation::new(self.config.algorithm.into());
249
250 validation.validate_exp = true;
252 validation.validate_nbf = false;
253 validation.leeway = 60; validation.set_required_spec_claims(&["exp", "sub"]);
257
258 if let Some(ref issuer) = self.config.issuer {
260 validation.set_issuer(&[issuer]);
261 }
262
263 if let Some(ref audience) = self.config.audience {
265 validation.set_audience(&[audience]);
266 } else {
267 validation.validate_aud = false;
268 }
269
270 let token_data =
271 decode::<Claims>(token, key, &validation).map_err(|e| self.map_jwt_error(e))?;
272
273 Ok(token_data.claims)
274 }
275
276 fn map_jwt_error(&self, e: jsonwebtoken::errors::Error) -> AuthError {
278 match e.kind() {
279 jsonwebtoken::errors::ErrorKind::ExpiredSignature => AuthError::TokenExpired,
280 jsonwebtoken::errors::ErrorKind::InvalidSignature => {
281 AuthError::InvalidToken("Invalid signature".to_string())
282 }
283 jsonwebtoken::errors::ErrorKind::InvalidToken => {
284 AuthError::InvalidToken("Invalid token format".to_string())
285 }
286 jsonwebtoken::errors::ErrorKind::MissingRequiredClaim(claim) => {
287 AuthError::InvalidToken(format!("Missing required claim: {}", claim))
288 }
289 jsonwebtoken::errors::ErrorKind::InvalidIssuer => {
290 AuthError::InvalidToken("Invalid issuer".to_string())
291 }
292 jsonwebtoken::errors::ErrorKind::InvalidAudience => {
293 AuthError::InvalidToken("Invalid audience".to_string())
294 }
295 _ => AuthError::InvalidToken(e.to_string()),
296 }
297 }
298
299 fn decode_without_verification(&self, token: &str) -> Result<Claims, AuthError> {
301 let token_data =
302 dangerous::insecure_decode::<Claims>(token).map_err(|e| match e.kind() {
303 jsonwebtoken::errors::ErrorKind::InvalidToken => {
304 AuthError::InvalidToken("Invalid token format".to_string())
305 }
306 _ => AuthError::InvalidToken(e.to_string()),
307 })?;
308
309 if token_data.claims.is_expired() {
311 return Err(AuthError::TokenExpired);
312 }
313
314 Ok(token_data.claims)
315 }
316}
317
318#[derive(Debug, Clone, thiserror::Error)]
320pub enum AuthError {
321 #[error("Missing authorization header")]
322 MissingHeader,
323 #[error("Invalid authorization header format")]
324 InvalidHeader,
325 #[error("Invalid token: {0}")]
326 InvalidToken(String),
327 #[error("Token expired")]
328 TokenExpired,
329}
330
331pub fn extract_token(req: &Request<Body>) -> Option<String> {
333 req.headers()
334 .get(axum::http::header::AUTHORIZATION)
335 .and_then(|v| v.to_str().ok())
336 .filter(|header| header.starts_with("Bearer "))
337 .map(|header| header.trim_start_matches("Bearer ").trim().to_string())
338}
339
340pub async fn extract_auth_context_async(
342 token: Option<String>,
343 middleware: &AuthMiddleware,
344) -> AuthContext {
345 match token {
346 Some(token) => match middleware.validate_token_async(&token).await {
347 Ok(claims) => build_auth_context_from_claims(claims),
348 Err(e) => {
349 tracing::warn!(error = %e, "Token validation failed");
350 AuthContext::unauthenticated()
351 }
352 },
353 None => AuthContext::unauthenticated(),
354 }
355}
356
357pub fn build_auth_context_from_claims(claims: Claims) -> AuthContext {
363 let user_id = claims.user_id();
365
366 let mut custom_claims = claims.custom;
368 custom_claims.insert("sub".to_string(), serde_json::Value::String(claims.sub));
369
370 match user_id {
371 Some(uuid) => {
372 AuthContext::authenticated(uuid, claims.roles, custom_claims)
374 }
375 None => {
376 AuthContext::authenticated_without_uuid(claims.roles, custom_claims)
379 }
380 }
381}
382
383pub async fn auth_middleware(
385 State(middleware): State<Arc<AuthMiddleware>>,
386 req: Request<Body>,
387 next: Next,
388) -> Response {
389 let token = extract_token(&req);
390 tracing::trace!(
391 token_present = token.is_some(),
392 "Auth middleware processing request"
393 );
394
395 let auth_context = extract_auth_context_async(token, &middleware).await;
396 tracing::trace!(
397 authenticated = auth_context.is_authenticated(),
398 "Auth context created"
399 );
400
401 let mut req = req;
402 req.extensions_mut().insert(auth_context);
403
404 next.run(req).await
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header, encode};

    /// Claims for subject `test-user-id` with role `user`.
    ///
    /// When `expired` is true, `exp` is one hour in the past, well beyond the
    /// 60-second leeway. Otherwise `exp` is one hour in the future.
    fn create_test_claims(expired: bool) -> Claims {
        use forge_core::auth::ClaimsBuilder;

        let duration = if expired { -3600 } else { 3600 };
        ClaimsBuilder::new()
            .subject("test-user-id")
            .role("user")
            .duration_secs(duration)
            .build()
            .unwrap()
    }

    /// Signs `claims` with HS256 using `secret`.
    fn create_test_token(claims: &Claims, secret: &str) -> String {
        encode(
            &Header::default(),
            claims,
            &EncodingKey::from_secret(secret.as_bytes()),
        )
        .unwrap()
    }

    /// Builds a request carrying `value` as its `Authorization` header, if given.
    fn request_with_auth(value: Option<&str>) -> Request<Body> {
        let mut builder = axum::http::Request::builder().uri("/");
        if let Some(value) = value {
            builder = builder.header(axum::http::header::AUTHORIZATION, value);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[test]
    fn test_auth_config_default() {
        let config = AuthConfig::default();
        assert!(config.allow_anonymous);
        assert_eq!(config.algorithm, JwtAlgorithm::HS256);
        assert!(!config.skip_verification);
    }

    #[test]
    fn test_auth_config_dev_mode() {
        let config = AuthConfig::dev_mode();
        assert!(config.skip_verification);
        assert!(config.allow_anonymous);
    }

    #[test]
    fn test_auth_middleware_permissive() {
        let middleware = AuthMiddleware::permissive();
        assert!(middleware.config.skip_verification);
    }

    #[tokio::test]
    async fn test_valid_token_with_correct_secret() {
        let secret = "test-secret-key";
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(secret));

        let token = create_test_token(&create_test_claims(false), secret);

        let validated_claims = middleware.validate_token_async(&token).await.unwrap();
        assert_eq!(validated_claims.sub, "test-user-id");
    }

    #[tokio::test]
    async fn test_valid_token_with_wrong_secret() {
        let middleware = AuthMiddleware::new(AuthConfig::with_secret("correct-secret"));

        let token = create_test_token(&create_test_claims(false), "wrong-secret");

        match middleware.validate_token_async(&token).await {
            Err(AuthError::InvalidToken(_)) => {}
            _ => panic!("Expected InvalidToken error"),
        }
    }

    #[tokio::test]
    async fn test_expired_token() {
        let secret = "test-secret";
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(secret));

        let token = create_test_token(&create_test_claims(true), secret);

        match middleware.validate_token_async(&token).await {
            Err(AuthError::TokenExpired) => {}
            _ => panic!("Expected TokenExpired error"),
        }
    }

    #[tokio::test]
    async fn test_tampered_token() {
        use forge_core::auth::ClaimsBuilder;

        let secret = "test-secret";
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(secret));

        let genuine = create_test_token(&create_test_claims(false), secret);
        let attacker_claims = ClaimsBuilder::new()
            .subject("attacker")
            .role("admin")
            .duration_secs(3600)
            .build()
            .unwrap();
        let attacker = create_test_token(&attacker_claims, "attacker-secret");

        // Splice the attacker's payload onto the genuine header and signature. Every segment
        // stays well-formed base64, so only the signature check can reject the result. The
        // old approach of flipping the last character could fail on base64 padding bits
        // before the signature was ever checked.
        let genuine_parts: Vec<&str> = genuine.split('.').collect();
        let attacker_parts: Vec<&str> = attacker.split('.').collect();
        assert_eq!(genuine_parts.len(), 3);
        assert_eq!(attacker_parts.len(), 3);
        let forged = format!(
            "{}.{}.{}",
            genuine_parts[0], attacker_parts[1], genuine_parts[2]
        );

        match middleware.validate_token_async(&forged).await {
            Err(AuthError::InvalidToken(msg)) => assert_eq!(msg, "Invalid signature"),
            _ => panic!("Expected InvalidToken(\"Invalid signature\") error"),
        }
    }

    #[tokio::test]
    async fn test_missing_secret_rejects_tokens() {
        let middleware = AuthMiddleware::new(AuthConfig::default());

        let token = create_test_token(&create_test_claims(false), "whatever");

        match middleware.validate_token_async(&token).await {
            Err(AuthError::InvalidToken(msg)) => {
                assert_eq!(msg, "JWT secret not configured for HMAC")
            }
            _ => panic!("Expected missing-secret error"),
        }
    }

    #[tokio::test]
    async fn test_empty_secret_rejects_tokens() {
        // An empty secret must never be accepted as a valid HMAC key, even for a token
        // signed with that same empty key.
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(""));

        let token = create_test_token(&create_test_claims(false), "");

        assert!(middleware.validate_token_async(&token).await.is_err());
    }

    #[tokio::test]
    async fn test_rsa_without_jwks_is_rejected() {
        let config = AuthConfig {
            algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        let middleware = AuthMiddleware::new(config);

        match middleware.validate_token_async("a.b.c").await {
            Err(AuthError::InvalidToken(msg)) => assert_eq!(msg, "JWKS URL not configured for RSA"),
            _ => panic!("Expected JWKS configuration error"),
        }
    }

    #[tokio::test]
    async fn test_dev_mode_skips_signature() {
        let middleware = AuthMiddleware::new(AuthConfig::dev_mode());

        // Signed with an arbitrary secret: dev mode must accept it without verification.
        let token = create_test_token(&create_test_claims(false), "any-secret");

        assert!(middleware.validate_token_async(&token).await.is_ok());
    }

    #[tokio::test]
    async fn test_dev_mode_still_checks_expiration() {
        let middleware = AuthMiddleware::new(AuthConfig::dev_mode());

        let token = create_test_token(&create_test_claims(true), "any-secret");

        match middleware.validate_token_async(&token).await {
            Err(AuthError::TokenExpired) => {}
            _ => panic!("Expected TokenExpired error even in dev mode"),
        }
    }

    #[tokio::test]
    async fn test_invalid_token_format() {
        let middleware = AuthMiddleware::new(AuthConfig::with_secret("secret"));

        match middleware.validate_token_async("not-a-valid-jwt").await {
            Err(AuthError::InvalidToken(_)) => {}
            _ => panic!("Expected InvalidToken error"),
        }
    }

    #[test]
    fn test_extract_token() {
        assert_eq!(
            extract_token(&request_with_auth(Some("Bearer abc.def.ghi"))).as_deref(),
            Some("abc.def.ghi")
        );
        assert_eq!(
            extract_token(&request_with_auth(Some("Basic dXNlcjpwYXNz"))),
            None
        );
        assert_eq!(extract_token(&request_with_auth(None)), None);
    }

    #[test]
    fn test_algorithm_conversion() {
        assert_eq!(Algorithm::from(JwtAlgorithm::HS256), Algorithm::HS256);
        assert_eq!(Algorithm::from(JwtAlgorithm::HS384), Algorithm::HS384);
        assert_eq!(Algorithm::from(JwtAlgorithm::HS512), Algorithm::HS512);
        assert_eq!(Algorithm::from(JwtAlgorithm::RS256), Algorithm::RS256);
        assert_eq!(Algorithm::from(JwtAlgorithm::RS384), Algorithm::RS384);
        assert_eq!(Algorithm::from(JwtAlgorithm::RS512), Algorithm::RS512);
    }

    #[test]
    fn test_is_hmac_and_is_rsa() {
        let hmac_config = AuthConfig::with_secret("test");
        assert!(hmac_config.is_hmac());
        assert!(!hmac_config.is_rsa());

        let rsa_config = AuthConfig {
            algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        assert!(!rsa_config.is_hmac());
        assert!(rsa_config.is_rsa());
    }
}