1use std::sync::Arc;
2
3use axum::{
4 body::Body,
5 extract::{Request, State},
6 http::StatusCode,
7 middleware::Next,
8 response::{IntoResponse, Json, Response},
9};
10use forge_core::auth::Claims;
11use forge_core::config::JwtAlgorithm as CoreJwtAlgorithm;
12use forge_core::function::AuthContext;
13use jsonwebtoken::{Algorithm, DecodingKey, Validation, dangerous, decode};
14use tracing::debug;
15
16use super::jwks::JwksClient;
17
18#[derive(Debug, Clone)]
20pub struct AuthConfig {
21 pub jwt_secret: Option<String>,
23 pub algorithm: JwtAlgorithm,
25 pub jwks_client: Option<Arc<JwksClient>>,
27 pub issuer: Option<String>,
29 pub audience: Option<String>,
31 pub skip_verification: bool,
33}
34
35impl Default for AuthConfig {
36 fn default() -> Self {
37 Self {
38 jwt_secret: None,
39 algorithm: JwtAlgorithm::HS256,
40 jwks_client: None,
41 issuer: None,
42 audience: None,
43 skip_verification: false,
44 }
45 }
46}
47
48impl AuthConfig {
49 pub fn from_forge_config(
51 config: &forge_core::config::AuthConfig,
52 ) -> Result<Self, super::jwks::JwksError> {
53 let algorithm = JwtAlgorithm::from(config.jwt_algorithm);
54
55 let jwks_client = config
56 .jwks_url
57 .as_ref()
58 .map(|url| JwksClient::new(url.clone(), config.jwks_cache_ttl_secs).map(Arc::new))
59 .transpose()?;
60
61 Ok(Self {
62 jwt_secret: config.jwt_secret.clone(),
63 algorithm,
64 jwks_client,
65 issuer: config.jwt_issuer.clone(),
66 audience: config.jwt_audience.clone(),
67 skip_verification: false,
68 })
69 }
70
71 pub fn with_secret(secret: impl Into<String>) -> Self {
73 Self {
74 jwt_secret: Some(secret.into()),
75 ..Default::default()
76 }
77 }
78
79 pub fn dev_mode() -> Self {
82 Self {
83 jwt_secret: None,
84 algorithm: JwtAlgorithm::HS256,
85 jwks_client: None,
86 issuer: None,
87 audience: None,
88 skip_verification: true,
89 }
90 }
91
92 pub fn is_hmac(&self) -> bool {
94 matches!(
95 self.algorithm,
96 JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
97 )
98 }
99
100 pub fn is_rsa(&self) -> bool {
102 matches!(
103 self.algorithm,
104 JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
105 )
106 }
107}
108
/// JWT signing algorithms accepted by the auth middleware.
///
/// `HS*` variants use a shared HMAC secret. `RS*` variants use RSA keys.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum JwtAlgorithm {
    /// HMAC with SHA-256 (the default).
    #[default]
    HS256,
    /// HMAC with SHA-384.
    HS384,
    /// HMAC with SHA-512.
    HS512,
    /// RSASSA-PKCS1-v1_5 with SHA-256.
    RS256,
    /// RSASSA-PKCS1-v1_5 with SHA-384.
    RS384,
    /// RSASSA-PKCS1-v1_5 with SHA-512.
    RS512,
}
120
121impl From<JwtAlgorithm> for Algorithm {
122 fn from(alg: JwtAlgorithm) -> Self {
123 match alg {
124 JwtAlgorithm::HS256 => Algorithm::HS256,
125 JwtAlgorithm::HS384 => Algorithm::HS384,
126 JwtAlgorithm::HS512 => Algorithm::HS512,
127 JwtAlgorithm::RS256 => Algorithm::RS256,
128 JwtAlgorithm::RS384 => Algorithm::RS384,
129 JwtAlgorithm::RS512 => Algorithm::RS512,
130 }
131 }
132}
133
134impl From<CoreJwtAlgorithm> for JwtAlgorithm {
135 fn from(alg: CoreJwtAlgorithm) -> Self {
136 match alg {
137 CoreJwtAlgorithm::HS256 => JwtAlgorithm::HS256,
138 CoreJwtAlgorithm::HS384 => JwtAlgorithm::HS384,
139 CoreJwtAlgorithm::HS512 => JwtAlgorithm::HS512,
140 CoreJwtAlgorithm::RS256 => JwtAlgorithm::RS256,
141 CoreJwtAlgorithm::RS384 => JwtAlgorithm::RS384,
142 CoreJwtAlgorithm::RS512 => JwtAlgorithm::RS512,
143 }
144 }
145}
146
147#[derive(Clone)]
149pub struct AuthMiddleware {
150 config: Arc<AuthConfig>,
151 hmac_key: Option<DecodingKey>,
153}
154
155impl std::fmt::Debug for AuthMiddleware {
156 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157 f.debug_struct("AuthMiddleware")
158 .field("config", &self.config)
159 .field("hmac_key", &self.hmac_key.is_some())
160 .finish()
161 }
162}
163
164impl AuthMiddleware {
165 pub fn new(config: AuthConfig) -> Self {
167 if config.skip_verification {
168 tracing::warn!("JWT signature verification is DISABLED. Do not use in production.");
169 }
170
171 let hmac_key = if config.skip_verification {
173 None
174 } else if config.is_hmac() {
175 config
176 .jwt_secret
177 .as_ref()
178 .filter(|s| !s.is_empty())
179 .map(|secret| DecodingKey::from_secret(secret.as_bytes()))
180 } else {
181 None
182 };
183
184 Self {
185 config: Arc::new(config),
186 hmac_key,
187 }
188 }
189
190 pub fn permissive() -> Self {
193 Self::new(AuthConfig::dev_mode())
194 }
195
196 pub fn config(&self) -> &AuthConfig {
198 &self.config
199 }
200
201 pub async fn validate_token_async(&self, token: &str) -> Result<Claims, AuthError> {
203 if self.config.skip_verification {
204 return self.decode_without_verification(token);
205 }
206
207 if self.config.is_hmac() {
208 self.validate_hmac(token)
209 } else {
210 self.validate_rsa(token).await
211 }
212 }
213
214 fn validate_hmac(&self, token: &str) -> Result<Claims, AuthError> {
216 let key = self.hmac_key.as_ref().ok_or_else(|| {
217 AuthError::InvalidToken("JWT secret not configured for HMAC".to_string())
218 })?;
219
220 self.decode_and_validate(token, key)
221 }
222
223 async fn validate_rsa(&self, token: &str) -> Result<Claims, AuthError> {
225 let jwks = self.config.jwks_client.as_ref().ok_or_else(|| {
226 AuthError::InvalidToken("JWKS URL not configured for RSA".to_string())
227 })?;
228
229 let header = jsonwebtoken::decode_header(token)
231 .map_err(|e| AuthError::InvalidToken(format!("Invalid token header: {}", e)))?;
232
233 debug!(kid = ?header.kid, alg = ?header.alg, "Validating RSA token");
234
235 let key = if let Some(kid) = header.kid {
237 jwks.get_key(&kid).await.map_err(|e| {
238 AuthError::InvalidToken(format!("Failed to get key '{}': {}", kid, e))
239 })?
240 } else {
241 jwks.get_any_key()
242 .await
243 .map_err(|e| AuthError::InvalidToken(format!("Failed to get JWKS key: {}", e)))?
244 };
245
246 self.decode_and_validate(token, &key)
247 }
248
249 fn decode_and_validate(&self, token: &str, key: &DecodingKey) -> Result<Claims, AuthError> {
251 let mut validation = Validation::new(self.config.algorithm.into());
252
253 validation.validate_exp = true;
255 validation.validate_nbf = false;
256 validation.leeway = 60; validation.set_required_spec_claims(&["exp", "sub"]);
260
261 if let Some(ref issuer) = self.config.issuer {
263 validation.set_issuer(&[issuer]);
264 }
265
266 if let Some(ref audience) = self.config.audience {
268 validation.set_audience(&[audience]);
269 } else {
270 validation.validate_aud = false;
271 }
272
273 let token_data =
274 decode::<Claims>(token, key, &validation).map_err(|e| self.map_jwt_error(e))?;
275
276 Ok(token_data.claims)
277 }
278
279 fn map_jwt_error(&self, e: jsonwebtoken::errors::Error) -> AuthError {
281 match e.kind() {
282 jsonwebtoken::errors::ErrorKind::ExpiredSignature => AuthError::TokenExpired,
283 jsonwebtoken::errors::ErrorKind::InvalidSignature => {
284 AuthError::InvalidToken("Invalid signature".to_string())
285 }
286 jsonwebtoken::errors::ErrorKind::InvalidToken => {
287 AuthError::InvalidToken("Invalid token format".to_string())
288 }
289 jsonwebtoken::errors::ErrorKind::MissingRequiredClaim(claim) => {
290 AuthError::InvalidToken(format!("Missing required claim: {}", claim))
291 }
292 jsonwebtoken::errors::ErrorKind::InvalidIssuer => {
293 AuthError::InvalidToken("Invalid issuer".to_string())
294 }
295 jsonwebtoken::errors::ErrorKind::InvalidAudience => {
296 AuthError::InvalidToken("Invalid audience".to_string())
297 }
298 _ => AuthError::InvalidToken(e.to_string()),
299 }
300 }
301
302 fn decode_without_verification(&self, token: &str) -> Result<Claims, AuthError> {
304 let token_data =
305 dangerous::insecure_decode::<Claims>(token).map_err(|e| match e.kind() {
306 jsonwebtoken::errors::ErrorKind::InvalidToken => {
307 AuthError::InvalidToken("Invalid token format".to_string())
308 }
309 _ => AuthError::InvalidToken(e.to_string()),
310 })?;
311
312 if token_data.claims.is_expired() {
314 return Err(AuthError::TokenExpired);
315 }
316
317 Ok(token_data.claims)
318 }
319}
320
321#[derive(Debug, Clone, thiserror::Error)]
323pub enum AuthError {
324 #[error("Missing authorization header")]
325 MissingHeader,
326 #[error("Invalid authorization header format")]
327 InvalidHeader,
328 #[error("Invalid token: {0}")]
329 InvalidToken(String),
330 #[error("Token expired")]
331 TokenExpired,
332}
333
334pub fn extract_token(req: &Request<Body>) -> Result<Option<String>, AuthError> {
336 let Some(header_value) = req.headers().get(axum::http::header::AUTHORIZATION) else {
337 return Ok(None);
338 };
339
340 let header = header_value
341 .to_str()
342 .map_err(|_| AuthError::InvalidHeader)?;
343 let token = header
344 .strip_prefix("Bearer ")
345 .ok_or(AuthError::InvalidHeader)?
346 .trim();
347
348 if token.is_empty() {
349 return Err(AuthError::InvalidHeader);
350 }
351
352 Ok(Some(token.to_string()))
353}
354
355pub async fn extract_auth_context_async(
357 token: Option<String>,
358 middleware: &AuthMiddleware,
359) -> Result<AuthContext, AuthError> {
360 match token {
361 Some(token) => middleware
362 .validate_token_async(&token)
363 .await
364 .map(build_auth_context_from_claims),
365 None => Ok(AuthContext::unauthenticated()),
366 }
367}
368
369pub fn build_auth_context_from_claims(claims: Claims) -> AuthContext {
375 let user_id = claims.user_id();
377
378 let mut custom_claims = claims.custom;
380 custom_claims.insert("sub".to_string(), serde_json::Value::String(claims.sub));
381
382 match user_id {
383 Some(uuid) => {
384 AuthContext::authenticated(uuid, claims.roles, custom_claims)
386 }
387 None => {
388 AuthContext::authenticated_without_uuid(claims.roles, custom_claims)
391 }
392 }
393}
394
395pub async fn auth_middleware(
397 State(middleware): State<Arc<AuthMiddleware>>,
398 req: Request<Body>,
399 next: Next,
400) -> Response {
401 let token = match extract_token(&req) {
402 Ok(token) => token,
403 Err(e) => {
404 tracing::warn!(error = %e, "Invalid authorization header");
405 return (
406 StatusCode::UNAUTHORIZED,
407 Json(serde_json::json!({
408 "success": false,
409 "error": { "code": "UNAUTHORIZED", "message": "Invalid authorization header" }
410 })),
411 )
412 .into_response();
413 }
414 };
415 tracing::trace!(
416 token_present = token.is_some(),
417 "Auth middleware processing request"
418 );
419
420 let auth_context = match extract_auth_context_async(token, &middleware).await {
421 Ok(auth_context) => auth_context,
422 Err(e) => {
423 tracing::warn!(error = %e, "Token validation failed");
424 return (
425 StatusCode::UNAUTHORIZED,
426 Json(serde_json::json!({
427 "success": false,
428 "error": { "code": "UNAUTHORIZED", "message": "Invalid authentication token" }
429 })),
430 )
431 .into_response();
432 }
433 };
434 tracing::trace!(
435 authenticated = auth_context.is_authenticated(),
436 "Auth context created"
437 );
438
439 let mut req = req;
440 req.extensions_mut().insert(auth_context);
441
442 next.run(req).await
443}
444
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use forge_core::auth::ClaimsBuilder;
    use jsonwebtoken::{EncodingKey, Header, encode};

    /// Claims for `test-user-id` with role `user`.
    ///
    /// `expired` selects a negative duration, so the token is already expired.
    fn create_test_claims(expired: bool) -> Claims {
        let duration = if expired { -3600 } else { 3600 };
        ClaimsBuilder::new()
            .subject("test-user-id")
            .role("user")
            .duration_secs(duration)
            .build()
            .unwrap()
    }

    /// Signs `claims` with `secret` using the default (HS256) header.
    fn create_test_token(claims: &Claims, secret: &str) -> String {
        encode(
            &Header::default(),
            claims,
            &EncodingKey::from_secret(secret.as_bytes()),
        )
        .unwrap()
    }

    /// Builds a request whose `Authorization` header is `value`.
    fn request_with_auth(value: &str) -> Request<Body> {
        Request::builder()
            .header(axum::http::header::AUTHORIZATION, value)
            .body(Body::empty())
            .unwrap()
    }

    #[test]
    fn test_auth_config_default() {
        let config = AuthConfig::default();
        assert_eq!(config.algorithm, JwtAlgorithm::HS256);
        assert!(!config.skip_verification);
    }

    #[test]
    fn test_auth_config_dev_mode() {
        let config = AuthConfig::dev_mode();
        assert!(config.skip_verification);
    }

    #[test]
    fn test_auth_middleware_permissive() {
        let middleware = AuthMiddleware::permissive();
        assert!(middleware.config.skip_verification);
    }

    #[test]
    fn test_empty_secret_is_treated_as_missing() {
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(""));
        assert!(middleware.hmac_key.is_none());
    }

    #[tokio::test]
    async fn test_valid_token_with_correct_secret() {
        let secret = "test-secret-key";
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(secret));

        let token = create_test_token(&create_test_claims(false), secret);

        let validated_claims = middleware.validate_token_async(&token).await.unwrap();
        assert_eq!(validated_claims.sub, "test-user-id");
    }

    #[tokio::test]
    async fn test_valid_token_with_wrong_secret() {
        let middleware = AuthMiddleware::new(AuthConfig::with_secret("correct-secret"));

        let token = create_test_token(&create_test_claims(false), "wrong-secret");

        let result = middleware.validate_token_async(&token).await;
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }

    #[tokio::test]
    async fn test_expired_token() {
        let secret = "test-secret";
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(secret));

        let token = create_test_token(&create_test_claims(true), secret);

        let result = middleware.validate_token_async(&token).await;
        assert!(matches!(result, Err(AuthError::TokenExpired)));
    }

    #[tokio::test]
    async fn test_tampered_token() {
        let secret = "test-secret";
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(secret));

        let genuine = create_test_token(&create_test_claims(false), secret);
        let forged_claims = ClaimsBuilder::new()
            .subject("attacker")
            .role("admin")
            .duration_secs(3600)
            .build()
            .unwrap();
        let forged = create_test_token(&forged_claims, "attacker-secret");

        // Put the forged payload between the genuine header and signature.
        // Every segment is well-formed base64, so only signature verification
        // can reject it.
        let genuine_parts: Vec<&str> = genuine.split('.').collect();
        let forged_parts: Vec<&str> = forged.split('.').collect();
        let tampered = format!(
            "{}.{}.{}",
            genuine_parts[0], forged_parts[1], genuine_parts[2]
        );

        let result = middleware.validate_token_async(&tampered).await;
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }

    #[tokio::test]
    async fn test_hmac_without_secret_rejects_tokens() {
        let middleware = AuthMiddleware::new(AuthConfig::default());

        let token = create_test_token(&create_test_claims(false), "some-secret");

        let result = middleware.validate_token_async(&token).await;
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }

    #[tokio::test]
    async fn test_rsa_without_jwks_rejects_tokens() {
        let config = AuthConfig {
            algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        let middleware = AuthMiddleware::new(config);

        let token = create_test_token(&create_test_claims(false), "secret");

        let result = middleware.validate_token_async(&token).await;
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }

    #[tokio::test]
    async fn test_dev_mode_skips_signature() {
        let middleware = AuthMiddleware::new(AuthConfig::dev_mode());

        // Signed with an arbitrary secret. Dev mode must not check it.
        let token = create_test_token(&create_test_claims(false), "any-secret");

        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_dev_mode_still_checks_expiration() {
        let middleware = AuthMiddleware::new(AuthConfig::dev_mode());

        let token = create_test_token(&create_test_claims(true), "any-secret");

        let result = middleware.validate_token_async(&token).await;
        match result {
            Err(AuthError::TokenExpired) => {}
            _ => panic!("Expected TokenExpired error even in dev mode"),
        }
    }

    #[tokio::test]
    async fn test_dev_mode_rejects_malformed_token() {
        let middleware = AuthMiddleware::permissive();

        let result = middleware.validate_token_async("not-a-valid-jwt").await;
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }

    #[tokio::test]
    async fn test_invalid_token_format() {
        let middleware = AuthMiddleware::new(AuthConfig::with_secret("secret"));

        let result = middleware.validate_token_async("not-a-valid-jwt").await;
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }

    #[test]
    fn test_algorithm_conversion() {
        assert_eq!(Algorithm::from(JwtAlgorithm::HS256), Algorithm::HS256);
        assert_eq!(Algorithm::from(JwtAlgorithm::HS384), Algorithm::HS384);
        assert_eq!(Algorithm::from(JwtAlgorithm::HS512), Algorithm::HS512);
        assert_eq!(Algorithm::from(JwtAlgorithm::RS256), Algorithm::RS256);
        assert_eq!(Algorithm::from(JwtAlgorithm::RS384), Algorithm::RS384);
        assert_eq!(Algorithm::from(JwtAlgorithm::RS512), Algorithm::RS512);
    }

    #[test]
    fn test_is_hmac_and_is_rsa() {
        let hmac_config = AuthConfig::with_secret("test");
        assert!(hmac_config.is_hmac());
        assert!(!hmac_config.is_rsa());

        let rsa_config = AuthConfig {
            algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        assert!(!rsa_config.is_hmac());
        assert!(rsa_config.is_rsa());
    }

    #[test]
    fn test_extract_token_missing_header_is_none() {
        let req = Request::builder().body(Body::empty()).unwrap();
        assert!(matches!(extract_token(&req), Ok(None)));
    }

    #[test]
    fn test_extract_token_trims_bearer_token() {
        let req = request_with_auth("Bearer abc.def.ghi ");
        assert_eq!(
            extract_token(&req).unwrap().as_deref(),
            Some("abc.def.ghi")
        );
    }

    #[test]
    fn test_extract_token_rejects_empty_bearer_token() {
        let req = request_with_auth("Bearer ");
        assert!(matches!(extract_token(&req), Err(AuthError::InvalidHeader)));
    }

    #[test]
    fn test_extract_token_rejects_non_bearer_header() {
        let req = request_with_auth("Basic abc");
        assert!(matches!(extract_token(&req), Err(AuthError::InvalidHeader)));
    }

    #[tokio::test]
    async fn test_extract_auth_context_async_invalid_token_errors() {
        let middleware = AuthMiddleware::new(AuthConfig::with_secret("secret"));
        let result = extract_auth_context_async(Some("bad.token".to_string()), &middleware).await;
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }
}