1use std::sync::Arc;
2
3use axum::{
4 body::Body,
5 extract::{Request, State},
6 http::{StatusCode, header},
7 middleware::Next,
8 response::{IntoResponse, Json, Response},
9};
10use forge_core::auth::Claims;
11use forge_core::config::JwtAlgorithm as CoreJwtAlgorithm;
12use forge_core::function::AuthContext;
13use jsonwebtoken::{Algorithm, DecodingKey, Validation, dangerous, decode, encode};
14use tracing::debug;
15
16use super::jwks::JwksClient;
17
18#[derive(Debug, Clone)]
20pub struct AuthConfig {
21 pub jwt_secret: Option<String>,
23 pub algorithm: JwtAlgorithm,
25 pub jwks_client: Option<Arc<JwksClient>>,
27 pub issuer: Option<String>,
29 pub audience: Option<String>,
31 pub skip_verification: bool,
33}
34
35impl Default for AuthConfig {
36 fn default() -> Self {
37 Self {
38 jwt_secret: None,
39 algorithm: JwtAlgorithm::HS256,
40 jwks_client: None,
41 issuer: None,
42 audience: None,
43 skip_verification: false,
44 }
45 }
46}
47
48impl AuthConfig {
49 pub fn from_forge_config(
51 config: &forge_core::config::AuthConfig,
52 ) -> Result<Self, super::jwks::JwksError> {
53 let algorithm = JwtAlgorithm::from(config.jwt_algorithm);
54
55 let jwks_client = config
56 .jwks_url
57 .as_ref()
58 .map(|url| JwksClient::new(url.clone(), config.jwks_cache_ttl_secs).map(Arc::new))
59 .transpose()?;
60
61 Ok(Self {
62 jwt_secret: config.jwt_secret.clone(),
63 algorithm,
64 jwks_client,
65 issuer: config.jwt_issuer.clone(),
66 audience: config.jwt_audience.clone(),
67 skip_verification: false,
68 })
69 }
70
71 pub fn with_secret(secret: impl Into<String>) -> Self {
73 Self {
74 jwt_secret: Some(secret.into()),
75 ..Default::default()
76 }
77 }
78
79 pub fn dev_mode() -> Self {
82 Self {
83 jwt_secret: None,
84 algorithm: JwtAlgorithm::HS256,
85 jwks_client: None,
86 issuer: None,
87 audience: None,
88 skip_verification: true,
89 }
90 }
91
92 pub fn is_hmac(&self) -> bool {
94 matches!(
95 self.algorithm,
96 JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
97 )
98 }
99
100 pub fn is_rsa(&self) -> bool {
102 matches!(
103 self.algorithm,
104 JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
105 )
106 }
107}
108
/// JWT signing algorithms accepted by the auth middleware.
///
/// `HS*` variants are verified with the shared `jwt_secret`; `RS*` variants are
/// verified with keys obtained from the configured JWKS client.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum JwtAlgorithm {
    /// HMAC with SHA-256; the default.
    #[default]
    HS256,
    /// HMAC with SHA-384.
    HS384,
    /// HMAC with SHA-512.
    HS512,
    /// RSA signature with SHA-256.
    RS256,
    /// RSA signature with SHA-384.
    RS384,
    /// RSA signature with SHA-512.
    RS512,
}
120
121impl From<JwtAlgorithm> for Algorithm {
122 fn from(alg: JwtAlgorithm) -> Self {
123 match alg {
124 JwtAlgorithm::HS256 => Algorithm::HS256,
125 JwtAlgorithm::HS384 => Algorithm::HS384,
126 JwtAlgorithm::HS512 => Algorithm::HS512,
127 JwtAlgorithm::RS256 => Algorithm::RS256,
128 JwtAlgorithm::RS384 => Algorithm::RS384,
129 JwtAlgorithm::RS512 => Algorithm::RS512,
130 }
131 }
132}
133
134impl From<CoreJwtAlgorithm> for JwtAlgorithm {
135 fn from(alg: CoreJwtAlgorithm) -> Self {
136 match alg {
137 CoreJwtAlgorithm::HS256 => JwtAlgorithm::HS256,
138 CoreJwtAlgorithm::HS384 => JwtAlgorithm::HS384,
139 CoreJwtAlgorithm::HS512 => JwtAlgorithm::HS512,
140 CoreJwtAlgorithm::RS256 => JwtAlgorithm::RS256,
141 CoreJwtAlgorithm::RS384 => JwtAlgorithm::RS384,
142 CoreJwtAlgorithm::RS512 => JwtAlgorithm::RS512,
143 }
144 }
145}
146
/// Signs JWTs with the HMAC secret and algorithm taken from an [`AuthConfig`].
///
/// It does not derive `Debug`, so the secret is not exposed through `{:?}`
/// formatting.
#[derive(Clone)]
pub struct HmacTokenIssuer {
    /// Non-empty HMAC secret (guaranteed by `from_config`).
    secret: String,
    /// One of the HMAC algorithms (guaranteed by `from_config`).
    algorithm: Algorithm,
}
156
157impl HmacTokenIssuer {
158 pub fn from_config(config: &AuthConfig) -> Option<Self> {
160 if !config.is_hmac() {
161 return None;
162 }
163 let secret = config.jwt_secret.as_ref()?.clone();
164 if secret.is_empty() {
165 return None;
166 }
167 Some(Self {
168 secret,
169 algorithm: config.algorithm.into(),
170 })
171 }
172}
173
174impl forge_core::TokenIssuer for HmacTokenIssuer {
175 fn sign(&self, claims: &Claims) -> forge_core::Result<String> {
176 let header = jsonwebtoken::Header::new(self.algorithm);
177 encode(
178 &header,
179 claims,
180 &jsonwebtoken::EncodingKey::from_secret(self.secret.as_bytes()),
181 )
182 .map_err(|e| forge_core::ForgeError::Internal(format!("token signing error: {e}")))
183 }
184}
185
/// State for the JWT authentication middleware.
#[derive(Clone)]
pub struct AuthMiddleware {
    /// Validation settings, shared between clones through an `Arc`.
    config: Arc<AuthConfig>,
    /// HMAC decoding key built once in `new`. It is `None` unless verification
    /// is enabled, the algorithm is HMAC, and a non-empty secret is set.
    hmac_key: Option<DecodingKey>,
}
193
194impl std::fmt::Debug for AuthMiddleware {
195 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196 f.debug_struct("AuthMiddleware")
197 .field("config", &self.config)
198 .field("hmac_key", &self.hmac_key.is_some())
199 .finish()
200 }
201}
202
203impl AuthMiddleware {
204 pub fn new(config: AuthConfig) -> Self {
206 if config.skip_verification {
207 tracing::warn!("JWT signature verification is DISABLED. Do not use in production.");
208 }
209
210 let hmac_key = if config.skip_verification {
212 None
213 } else if config.is_hmac() {
214 config
215 .jwt_secret
216 .as_ref()
217 .filter(|s| !s.is_empty())
218 .map(|secret| DecodingKey::from_secret(secret.as_bytes()))
219 } else {
220 None
221 };
222
223 Self {
224 config: Arc::new(config),
225 hmac_key,
226 }
227 }
228
229 pub fn permissive() -> Self {
232 Self::new(AuthConfig::dev_mode())
233 }
234
235 pub fn config(&self) -> &AuthConfig {
237 &self.config
238 }
239
240 pub async fn validate_token_async(&self, token: &str) -> Result<Claims, AuthError> {
242 if self.config.skip_verification {
243 return self.decode_without_verification(token);
244 }
245
246 if self.config.is_hmac() {
247 self.validate_hmac(token)
248 } else {
249 self.validate_rsa(token).await
250 }
251 }
252
253 fn validate_hmac(&self, token: &str) -> Result<Claims, AuthError> {
255 let key = self.hmac_key.as_ref().ok_or_else(|| {
256 AuthError::InvalidToken("JWT secret not configured for HMAC".to_string())
257 })?;
258
259 self.decode_and_validate(token, key)
260 }
261
262 async fn validate_rsa(&self, token: &str) -> Result<Claims, AuthError> {
264 let jwks = self.config.jwks_client.as_ref().ok_or_else(|| {
265 AuthError::InvalidToken("JWKS URL not configured for RSA".to_string())
266 })?;
267
268 let header = jsonwebtoken::decode_header(token)
270 .map_err(|e| AuthError::InvalidToken(format!("Invalid token header: {}", e)))?;
271
272 debug!(kid = ?header.kid, alg = ?header.alg, "Validating RSA token");
273
274 let key = if let Some(kid) = header.kid {
276 jwks.get_key(&kid).await.map_err(|e| {
277 AuthError::InvalidToken(format!("Failed to get key '{}': {}", kid, e))
278 })?
279 } else {
280 jwks.get_any_key()
281 .await
282 .map_err(|e| AuthError::InvalidToken(format!("Failed to get JWKS key: {}", e)))?
283 };
284
285 self.decode_and_validate(token, &key)
286 }
287
288 fn decode_and_validate(&self, token: &str, key: &DecodingKey) -> Result<Claims, AuthError> {
290 let mut validation = Validation::new(self.config.algorithm.into());
291
292 validation.validate_exp = true;
294 validation.validate_nbf = false;
295 validation.leeway = 60; validation.set_required_spec_claims(&["exp", "sub"]);
299
300 if let Some(ref issuer) = self.config.issuer {
302 validation.set_issuer(&[issuer]);
303 }
304
305 if let Some(ref audience) = self.config.audience {
307 validation.set_audience(&[audience]);
308 } else {
309 validation.validate_aud = false;
310 }
311
312 let token_data =
313 decode::<Claims>(token, key, &validation).map_err(|e| self.map_jwt_error(e))?;
314
315 Ok(token_data.claims)
316 }
317
318 fn map_jwt_error(&self, e: jsonwebtoken::errors::Error) -> AuthError {
320 match e.kind() {
321 jsonwebtoken::errors::ErrorKind::ExpiredSignature => AuthError::TokenExpired,
322 jsonwebtoken::errors::ErrorKind::InvalidSignature => {
323 AuthError::InvalidToken("Invalid signature".to_string())
324 }
325 jsonwebtoken::errors::ErrorKind::InvalidToken => {
326 AuthError::InvalidToken("Invalid token format".to_string())
327 }
328 jsonwebtoken::errors::ErrorKind::MissingRequiredClaim(claim) => {
329 AuthError::InvalidToken(format!("Missing required claim: {}", claim))
330 }
331 jsonwebtoken::errors::ErrorKind::InvalidIssuer => {
332 AuthError::InvalidToken("Invalid issuer".to_string())
333 }
334 jsonwebtoken::errors::ErrorKind::InvalidAudience => {
335 AuthError::InvalidToken("Invalid audience".to_string())
336 }
337 _ => AuthError::InvalidToken(e.to_string()),
338 }
339 }
340
341 fn decode_without_verification(&self, token: &str) -> Result<Claims, AuthError> {
343 let token_data =
344 dangerous::insecure_decode::<Claims>(token).map_err(|e| match e.kind() {
345 jsonwebtoken::errors::ErrorKind::InvalidToken => {
346 AuthError::InvalidToken("Invalid token format".to_string())
347 }
348 _ => AuthError::InvalidToken(e.to_string()),
349 })?;
350
351 if token_data.claims.is_expired() {
353 return Err(AuthError::TokenExpired);
354 }
355
356 Ok(token_data.claims)
357 }
358}
359
/// Reasons a request could not be authenticated.
#[derive(Debug, Clone, thiserror::Error)]
pub enum AuthError {
    /// No authorization header was supplied. Note that `extract_token` does not
    /// produce this: it treats a missing header as an anonymous request.
    #[error("Missing authorization header")]
    MissingHeader,
    /// The `Authorization` header is not valid text, does not use the `Bearer`
    /// scheme, or carries an empty token.
    #[error("Invalid authorization header format")]
    InvalidHeader,
    /// The token failed decoding or validation; the message says why.
    #[error("Invalid token: {0}")]
    InvalidToken(String),
    /// The token's expiry time has passed.
    #[error("Token expired")]
    TokenExpired,
}
372
373pub fn extract_token(req: &Request<Body>) -> Result<Option<String>, AuthError> {
375 let Some(header_value) = req.headers().get(axum::http::header::AUTHORIZATION) else {
376 return Ok(None);
377 };
378
379 let header = header_value
380 .to_str()
381 .map_err(|_| AuthError::InvalidHeader)?;
382 let token = header
383 .strip_prefix("Bearer ")
384 .ok_or(AuthError::InvalidHeader)?
385 .trim();
386
387 if token.is_empty() {
388 return Err(AuthError::InvalidHeader);
389 }
390
391 Ok(Some(token.to_string()))
392}
393
394pub async fn extract_auth_context_async(
396 token: Option<String>,
397 middleware: &AuthMiddleware,
398) -> Result<AuthContext, AuthError> {
399 match token {
400 Some(token) => middleware
401 .validate_token_async(&token)
402 .await
403 .map(build_auth_context_from_claims),
404 None => Ok(AuthContext::unauthenticated()),
405 }
406}
407
408pub fn build_auth_context_from_claims(claims: Claims) -> AuthContext {
414 let user_id = claims.user_id();
416
417 let mut custom_claims = claims.custom;
419 custom_claims.insert("sub".to_string(), serde_json::Value::String(claims.sub));
420
421 match user_id {
422 Some(uuid) => {
423 AuthContext::authenticated(uuid, claims.roles, custom_claims)
425 }
426 None => {
427 AuthContext::authenticated_without_uuid(claims.roles, custom_claims)
430 }
431 }
432}
433
434pub async fn auth_middleware(
436 State(middleware): State<Arc<AuthMiddleware>>,
437 req: Request<Body>,
438 next: Next,
439) -> Response {
440 let token = match extract_token(&req) {
441 Ok(token) => token,
442 Err(e) => {
443 tracing::warn!(error = %e, "Invalid authorization header");
444 return (
445 StatusCode::UNAUTHORIZED,
446 Json(serde_json::json!({
447 "success": false,
448 "error": { "code": "UNAUTHORIZED", "message": "Invalid authorization header" }
449 })),
450 )
451 .into_response();
452 }
453 };
454 tracing::trace!(
455 token_present = token.is_some(),
456 "Auth middleware processing request"
457 );
458
459 let auth_context = match extract_auth_context_async(token, &middleware).await {
460 Ok(auth_context) => auth_context,
461 Err(e) => {
462 tracing::warn!(error = %e, "Token validation failed");
463 return (
464 StatusCode::UNAUTHORIZED,
465 Json(serde_json::json!({
466 "success": false,
467 "error": { "code": "UNAUTHORIZED", "message": "Invalid authentication token" }
468 })),
469 )
470 .into_response();
471 }
472 };
473 tracing::trace!(
474 authenticated = auth_context.is_authenticated(),
475 "Auth context created"
476 );
477
478 let should_set_cookie =
484 auth_context.is_authenticated() && middleware.config.jwt_secret.is_some();
485
486 let req_is_https = req
487 .headers()
488 .get("x-forwarded-proto")
489 .and_then(|v| v.to_str().ok())
490 .map(|s| s == "https")
491 .unwrap_or(false);
492
493 let has_session_cookie = req
495 .headers()
496 .get(header::COOKIE)
497 .and_then(|v| v.to_str().ok())
498 .map(|c| c.contains("forge_session="))
499 .unwrap_or(false);
500
501 let should_set_cookie = should_set_cookie && !has_session_cookie;
502
503 let mut req = req;
504 req.extensions_mut().insert(auth_context.clone());
505
506 let mut response = next.run(req).await;
507
508 if should_set_cookie
509 && let Some(subject) = auth_context.subject()
510 && let Some(secret) = &middleware.config.jwt_secret
511 {
512 let cookie_value = sign_session_cookie(subject, secret);
513 let secure_flag = if req_is_https { "; Secure" } else { "" };
514 let cookie = format!(
515 "forge_session={cookie_value}; Path=/_api/oauth/; HttpOnly; SameSite=Lax; Max-Age=86400{secure_flag}"
516 );
517 if let Ok(val) = axum::http::HeaderValue::from_str(&cookie) {
518 response.headers_mut().append(header::SET_COOKIE, val);
519 }
520 }
521
522 response
523}
524
525pub fn sign_session_cookie(subject: &str, secret: &str) -> String {
529 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
530 use hmac::{Hmac, Mac};
531 use sha2::Sha256;
532
533 let expiry = chrono::Utc::now().timestamp() + 86400; let payload = format!("{subject}.{expiry}");
535
536 let mut mac =
537 Hmac::<Sha256>::new_from_slice(secret.as_bytes()).expect("HMAC accepts any key length");
538 mac.update(payload.as_bytes());
539 let sig = URL_SAFE_NO_PAD.encode(mac.finalize().into_bytes());
540
541 format!("{payload}.{sig}")
542}
543
544pub fn verify_session_cookie(cookie_value: &str, secret: &str) -> Option<String> {
547 use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
548 use hmac::{Hmac, Mac};
549 use sha2::Sha256;
550
551 let parts: Vec<&str> = cookie_value.rsplitn(2, '.').collect();
552 if parts.len() != 2 {
553 return None;
554 }
555 let sig_encoded = parts.first()?;
556 let payload = parts.get(1)?; let sig_bytes = URL_SAFE_NO_PAD.decode(sig_encoded).ok()?;
560 let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes()).ok()?;
561 mac.update(payload.as_bytes());
562 mac.verify_slice(&sig_bytes).ok()?;
563
564 let dot_pos = payload.rfind('.')?;
566 let subject = &payload[..dot_pos];
567 let expiry_str = &payload[dot_pos + 1..];
568 let expiry: i64 = expiry_str.parse().ok()?;
569
570 if chrono::Utc::now().timestamp() > expiry {
571 return None;
572 }
573
574 Some(subject.to_string())
575}
576
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header, encode};

    /// Claims for `test-user-id` with role `user`, expiring one hour from now,
    /// or one hour ago when `expired` is true.
    fn create_test_claims(expired: bool) -> Claims {
        use forge_core::auth::ClaimsBuilder;

        let mut builder = ClaimsBuilder::new().subject("test-user-id").role("user");

        if expired {
            builder = builder.duration_secs(-3600);
        } else {
            builder = builder.duration_secs(3600);
        }

        builder.build().unwrap()
    }

    /// Signs `claims` with HS256 (the default header) using `secret`.
    fn create_test_token(claims: &Claims, secret: &str) -> String {
        encode(
            &Header::default(),
            claims,
            &EncodingKey::from_secret(secret.as_bytes()),
        )
        .unwrap()
    }

    #[test]
    fn test_auth_config_default() {
        let config = AuthConfig::default();
        assert_eq!(config.algorithm, JwtAlgorithm::HS256);
        assert!(!config.skip_verification);
    }

    #[test]
    fn test_auth_config_dev_mode() {
        let config = AuthConfig::dev_mode();
        assert!(config.skip_verification);
    }

    #[test]
    fn test_auth_middleware_permissive() {
        let middleware = AuthMiddleware::permissive();
        assert!(middleware.config.skip_verification);
    }

    #[tokio::test]
    async fn test_valid_token_with_correct_secret() {
        let secret = "test-secret-key";
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(secret));

        let claims = create_test_claims(false);
        let token = create_test_token(&claims, secret);

        let validated_claims = middleware.validate_token_async(&token).await.unwrap();
        assert_eq!(validated_claims.sub, "test-user-id");
    }

    #[tokio::test]
    async fn test_valid_token_with_wrong_secret() {
        let middleware = AuthMiddleware::new(AuthConfig::with_secret("correct-secret"));

        let claims = create_test_claims(false);
        let token = create_test_token(&claims, "wrong-secret");

        let result = middleware.validate_token_async(&token).await;
        assert!(
            matches!(result, Err(AuthError::InvalidToken(_))),
            "Expected InvalidToken error"
        );
    }

    #[tokio::test]
    async fn test_expired_token() {
        let secret = "test-secret";
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(secret));

        let claims = create_test_claims(true);
        let token = create_test_token(&claims, secret);

        let result = middleware.validate_token_async(&token).await;
        assert!(
            matches!(result, Err(AuthError::TokenExpired)),
            "Expected TokenExpired error"
        );
    }

    #[tokio::test]
    async fn test_tampered_token() {
        let secret = "test-secret";
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(secret));

        let claims = create_test_claims(false);
        let mut token = create_test_token(&claims, secret);

        // Flip the last character of the signature.
        if let Some(last_char) = token.pop() {
            let replacement = if last_char == 'a' { 'b' } else { 'a' };
            token.push(replacement);
        }

        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_dev_mode_skips_signature() {
        let middleware = AuthMiddleware::new(AuthConfig::dev_mode());

        // Signed with an arbitrary secret the middleware never sees.
        let claims = create_test_claims(false);
        let token = create_test_token(&claims, "any-secret");

        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_dev_mode_still_checks_expiration() {
        let middleware = AuthMiddleware::new(AuthConfig::dev_mode());

        let claims = create_test_claims(true);
        let token = create_test_token(&claims, "any-secret");

        let result = middleware.validate_token_async(&token).await;
        assert!(
            matches!(result, Err(AuthError::TokenExpired)),
            "Expected TokenExpired error even in dev mode"
        );
    }

    #[tokio::test]
    async fn test_invalid_token_format() {
        let middleware = AuthMiddleware::new(AuthConfig::with_secret("secret"));

        let result = middleware.validate_token_async("not-a-valid-jwt").await;
        assert!(
            matches!(result, Err(AuthError::InvalidToken(_))),
            "Expected InvalidToken error"
        );
    }

    #[test]
    fn test_algorithm_conversion() {
        assert_eq!(Algorithm::from(JwtAlgorithm::HS256), Algorithm::HS256);
        assert_eq!(Algorithm::from(JwtAlgorithm::HS384), Algorithm::HS384);
        assert_eq!(Algorithm::from(JwtAlgorithm::HS512), Algorithm::HS512);
        assert_eq!(Algorithm::from(JwtAlgorithm::RS256), Algorithm::RS256);
        assert_eq!(Algorithm::from(JwtAlgorithm::RS384), Algorithm::RS384);
        assert_eq!(Algorithm::from(JwtAlgorithm::RS512), Algorithm::RS512);
    }

    #[test]
    fn test_is_hmac_and_is_rsa() {
        let hmac_config = AuthConfig::with_secret("test");
        assert!(hmac_config.is_hmac());
        assert!(!hmac_config.is_rsa());

        let rsa_config = AuthConfig {
            algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        assert!(!rsa_config.is_hmac());
        assert!(rsa_config.is_rsa());
    }

    #[test]
    fn test_extract_token_rejects_non_bearer_header() {
        let req = Request::builder()
            .header(axum::http::header::AUTHORIZATION, "Basic abc")
            .body(Body::empty())
            .unwrap();

        let result = extract_token(&req);
        assert!(matches!(result, Err(AuthError::InvalidHeader)));
    }

    #[test]
    fn test_extract_token_accepts_bearer_header() {
        let req = Request::builder()
            .header(axum::http::header::AUTHORIZATION, "Bearer abc.def.ghi")
            .body(Body::empty())
            .unwrap();

        assert_eq!(
            extract_token(&req).unwrap(),
            Some("abc.def.ghi".to_string())
        );
    }

    #[test]
    fn test_extract_token_missing_header_is_anonymous() {
        let req = Request::builder().body(Body::empty()).unwrap();
        assert_eq!(extract_token(&req).unwrap(), None);
    }

    #[test]
    fn test_extract_token_rejects_empty_bearer_token() {
        let req = Request::builder()
            .header(axum::http::header::AUTHORIZATION, "Bearer   ")
            .body(Body::empty())
            .unwrap();

        assert!(matches!(extract_token(&req), Err(AuthError::InvalidHeader)));
    }

    #[tokio::test]
    async fn test_extract_auth_context_async_invalid_token_errors() {
        let middleware = AuthMiddleware::new(AuthConfig::with_secret("secret"));
        let result = extract_auth_context_async(Some("bad.token".to_string()), &middleware).await;
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }

    #[test]
    fn test_session_cookie_round_trip() {
        let cookie = sign_session_cookie("user-123", "cookie-secret");
        assert_eq!(
            verify_session_cookie(&cookie, "cookie-secret"),
            Some("user-123".to_string())
        );
    }

    #[test]
    fn test_session_cookie_subject_with_dots_round_trips() {
        let cookie = sign_session_cookie("a.b.c", "cookie-secret");
        assert_eq!(
            verify_session_cookie(&cookie, "cookie-secret"),
            Some("a.b.c".to_string())
        );
    }

    #[test]
    fn test_session_cookie_rejects_wrong_secret() {
        let cookie = sign_session_cookie("user-123", "cookie-secret");
        assert_eq!(verify_session_cookie(&cookie, "other-secret"), None);
    }

    #[test]
    fn test_session_cookie_rejects_tampered_subject() {
        let cookie = sign_session_cookie("user-123", "cookie-secret");
        let forged = cookie.replacen("user-123", "admin", 1);
        assert_eq!(verify_session_cookie(&forged, "cookie-secret"), None);
    }

    #[test]
    fn test_session_cookie_rejects_malformed_values() {
        assert_eq!(verify_session_cookie("", "cookie-secret"), None);
        assert_eq!(verify_session_cookie("no-dots-here", "cookie-secret"), None);
        assert_eq!(
            verify_session_cookie("user.123.not base64!", "cookie-secret"),
            None
        );
    }

    #[test]
    fn test_session_cookie_rejects_expired() {
        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
        use hmac::{Hmac, Mac};
        use sha2::Sha256;

        // Correctly signed, but the expiry lies in the past.
        let payload = format!("user-123.{}", chrono::Utc::now().timestamp() - 10);
        let mut mac = Hmac::<Sha256>::new_from_slice(b"cookie-secret").unwrap();
        mac.update(payload.as_bytes());
        let sig = URL_SAFE_NO_PAD.encode(mac.finalize().into_bytes());

        assert_eq!(
            verify_session_cookie(&format!("{payload}.{sig}"), "cookie-secret"),
            None
        );
    }
}