1use std::sync::Arc;
2
3use axum::{
4 body::Body,
5 extract::{Request, State},
6 http::StatusCode,
7 middleware::Next,
8 response::{IntoResponse, Response},
9};
10use forge_core::auth::Claims;
11use forge_core::config::JwtAlgorithm as CoreJwtAlgorithm;
12use forge_core::function::AuthContext;
13use jsonwebtoken::{Algorithm, DecodingKey, Validation, dangerous, decode};
14use tracing::debug;
15
16use super::jwks::JwksClient;
17
18#[derive(Debug, Clone)]
20pub struct AuthConfig {
21 pub jwt_secret: Option<String>,
23 pub algorithm: JwtAlgorithm,
25 pub jwks_client: Option<Arc<JwksClient>>,
27 pub issuer: Option<String>,
29 pub audience: Option<String>,
31 pub skip_verification: bool,
33}
34
35impl Default for AuthConfig {
36 fn default() -> Self {
37 Self {
38 jwt_secret: None,
39 algorithm: JwtAlgorithm::HS256,
40 jwks_client: None,
41 issuer: None,
42 audience: None,
43 skip_verification: false,
44 }
45 }
46}
47
48impl AuthConfig {
49 pub fn from_forge_config(config: &forge_core::config::AuthConfig) -> Self {
51 let algorithm = JwtAlgorithm::from(config.jwt_algorithm);
52
53 let jwks_client = config
54 .jwks_url
55 .as_ref()
56 .map(|url| Arc::new(JwksClient::new(url.clone(), config.jwks_cache_ttl_secs)));
57
58 Self {
59 jwt_secret: config.jwt_secret.clone(),
60 algorithm,
61 jwks_client,
62 issuer: config.jwt_issuer.clone(),
63 audience: config.jwt_audience.clone(),
64 skip_verification: false,
65 }
66 }
67
68 pub fn with_secret(secret: impl Into<String>) -> Self {
70 Self {
71 jwt_secret: Some(secret.into()),
72 ..Default::default()
73 }
74 }
75
76 pub fn dev_mode() -> Self {
79 Self {
80 jwt_secret: None,
81 algorithm: JwtAlgorithm::HS256,
82 jwks_client: None,
83 issuer: None,
84 audience: None,
85 skip_verification: true,
86 }
87 }
88
89 pub fn is_hmac(&self) -> bool {
91 matches!(
92 self.algorithm,
93 JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
94 )
95 }
96
97 pub fn is_rsa(&self) -> bool {
99 matches!(
100 self.algorithm,
101 JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
102 )
103 }
104}
105
/// JWT signing algorithms supported by the auth middleware.
///
/// The `HS*` variants are HMAC based and the `RS*` variants are RSA based;
/// `HS256` is the default.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum JwtAlgorithm {
    /// HMAC with SHA-256.
    #[default]
    HS256,
    /// HMAC with SHA-384.
    HS384,
    /// HMAC with SHA-512.
    HS512,
    /// RSA signature with SHA-256.
    RS256,
    /// RSA signature with SHA-384.
    RS384,
    /// RSA signature with SHA-512.
    RS512,
}
117
118impl From<JwtAlgorithm> for Algorithm {
119 fn from(alg: JwtAlgorithm) -> Self {
120 match alg {
121 JwtAlgorithm::HS256 => Algorithm::HS256,
122 JwtAlgorithm::HS384 => Algorithm::HS384,
123 JwtAlgorithm::HS512 => Algorithm::HS512,
124 JwtAlgorithm::RS256 => Algorithm::RS256,
125 JwtAlgorithm::RS384 => Algorithm::RS384,
126 JwtAlgorithm::RS512 => Algorithm::RS512,
127 }
128 }
129}
130
131impl From<CoreJwtAlgorithm> for JwtAlgorithm {
132 fn from(alg: CoreJwtAlgorithm) -> Self {
133 match alg {
134 CoreJwtAlgorithm::HS256 => JwtAlgorithm::HS256,
135 CoreJwtAlgorithm::HS384 => JwtAlgorithm::HS384,
136 CoreJwtAlgorithm::HS512 => JwtAlgorithm::HS512,
137 CoreJwtAlgorithm::RS256 => JwtAlgorithm::RS256,
138 CoreJwtAlgorithm::RS384 => JwtAlgorithm::RS384,
139 CoreJwtAlgorithm::RS512 => JwtAlgorithm::RS512,
140 }
141 }
142}
143
144#[derive(Clone)]
146pub struct AuthMiddleware {
147 config: Arc<AuthConfig>,
148 hmac_key: Option<DecodingKey>,
150}
151
152impl std::fmt::Debug for AuthMiddleware {
153 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154 f.debug_struct("AuthMiddleware")
155 .field("config", &self.config)
156 .field("hmac_key", &self.hmac_key.is_some())
157 .finish()
158 }
159}
160
161impl AuthMiddleware {
162 pub fn new(config: AuthConfig) -> Self {
164 let hmac_key = if config.skip_verification {
166 None
167 } else if config.is_hmac() {
168 config
169 .jwt_secret
170 .as_ref()
171 .filter(|s| !s.is_empty())
172 .map(|secret| DecodingKey::from_secret(secret.as_bytes()))
173 } else {
174 None
175 };
176
177 Self {
178 config: Arc::new(config),
179 hmac_key,
180 }
181 }
182
183 pub fn permissive() -> Self {
186 Self::new(AuthConfig::dev_mode())
187 }
188
189 pub fn config(&self) -> &AuthConfig {
191 &self.config
192 }
193
194 pub async fn validate_token_async(&self, token: &str) -> Result<Claims, AuthError> {
196 if self.config.skip_verification {
197 return self.decode_without_verification(token);
198 }
199
200 if self.config.is_hmac() {
201 self.validate_hmac(token)
202 } else {
203 self.validate_rsa(token).await
204 }
205 }
206
207 fn validate_hmac(&self, token: &str) -> Result<Claims, AuthError> {
209 let key = self.hmac_key.as_ref().ok_or_else(|| {
210 AuthError::InvalidToken("JWT secret not configured for HMAC".to_string())
211 })?;
212
213 self.decode_and_validate(token, key)
214 }
215
216 async fn validate_rsa(&self, token: &str) -> Result<Claims, AuthError> {
218 let jwks = self.config.jwks_client.as_ref().ok_or_else(|| {
219 AuthError::InvalidToken("JWKS URL not configured for RSA".to_string())
220 })?;
221
222 let header = jsonwebtoken::decode_header(token)
224 .map_err(|e| AuthError::InvalidToken(format!("Invalid token header: {}", e)))?;
225
226 debug!(kid = ?header.kid, alg = ?header.alg, "Validating RSA token");
227
228 let key = if let Some(kid) = header.kid {
230 jwks.get_key(&kid).await.map_err(|e| {
231 AuthError::InvalidToken(format!("Failed to get key '{}': {}", kid, e))
232 })?
233 } else {
234 jwks.get_any_key()
235 .await
236 .map_err(|e| AuthError::InvalidToken(format!("Failed to get JWKS key: {}", e)))?
237 };
238
239 self.decode_and_validate(token, &key)
240 }
241
242 fn decode_and_validate(&self, token: &str, key: &DecodingKey) -> Result<Claims, AuthError> {
244 let mut validation = Validation::new(self.config.algorithm.into());
245
246 validation.validate_exp = true;
248 validation.validate_nbf = false;
249 validation.leeway = 60; validation.set_required_spec_claims(&["exp", "sub"]);
253
254 if let Some(ref issuer) = self.config.issuer {
256 validation.set_issuer(&[issuer]);
257 }
258
259 if let Some(ref audience) = self.config.audience {
261 validation.set_audience(&[audience]);
262 } else {
263 validation.validate_aud = false;
264 }
265
266 let token_data =
267 decode::<Claims>(token, key, &validation).map_err(|e| self.map_jwt_error(e))?;
268
269 Ok(token_data.claims)
270 }
271
272 fn map_jwt_error(&self, e: jsonwebtoken::errors::Error) -> AuthError {
274 match e.kind() {
275 jsonwebtoken::errors::ErrorKind::ExpiredSignature => AuthError::TokenExpired,
276 jsonwebtoken::errors::ErrorKind::InvalidSignature => {
277 AuthError::InvalidToken("Invalid signature".to_string())
278 }
279 jsonwebtoken::errors::ErrorKind::InvalidToken => {
280 AuthError::InvalidToken("Invalid token format".to_string())
281 }
282 jsonwebtoken::errors::ErrorKind::MissingRequiredClaim(claim) => {
283 AuthError::InvalidToken(format!("Missing required claim: {}", claim))
284 }
285 jsonwebtoken::errors::ErrorKind::InvalidIssuer => {
286 AuthError::InvalidToken("Invalid issuer".to_string())
287 }
288 jsonwebtoken::errors::ErrorKind::InvalidAudience => {
289 AuthError::InvalidToken("Invalid audience".to_string())
290 }
291 _ => AuthError::InvalidToken(e.to_string()),
292 }
293 }
294
295 fn decode_without_verification(&self, token: &str) -> Result<Claims, AuthError> {
297 let token_data =
298 dangerous::insecure_decode::<Claims>(token).map_err(|e| match e.kind() {
299 jsonwebtoken::errors::ErrorKind::InvalidToken => {
300 AuthError::InvalidToken("Invalid token format".to_string())
301 }
302 _ => AuthError::InvalidToken(e.to_string()),
303 })?;
304
305 if token_data.claims.is_expired() {
307 return Err(AuthError::TokenExpired);
308 }
309
310 Ok(token_data.claims)
311 }
312}
313
314#[derive(Debug, Clone, thiserror::Error)]
316pub enum AuthError {
317 #[error("Missing authorization header")]
318 MissingHeader,
319 #[error("Invalid authorization header format")]
320 InvalidHeader,
321 #[error("Invalid token: {0}")]
322 InvalidToken(String),
323 #[error("Token expired")]
324 TokenExpired,
325}
326
327pub fn extract_token(req: &Request<Body>) -> Result<Option<String>, AuthError> {
329 let Some(header_value) = req.headers().get(axum::http::header::AUTHORIZATION) else {
330 return Ok(None);
331 };
332
333 let header = header_value
334 .to_str()
335 .map_err(|_| AuthError::InvalidHeader)?;
336 let token = header
337 .strip_prefix("Bearer ")
338 .ok_or(AuthError::InvalidHeader)?
339 .trim();
340
341 if token.is_empty() {
342 return Err(AuthError::InvalidHeader);
343 }
344
345 Ok(Some(token.to_string()))
346}
347
348pub async fn extract_auth_context_async(
350 token: Option<String>,
351 middleware: &AuthMiddleware,
352) -> Result<AuthContext, AuthError> {
353 match token {
354 Some(token) => middleware
355 .validate_token_async(&token)
356 .await
357 .map(build_auth_context_from_claims),
358 None => Ok(AuthContext::unauthenticated()),
359 }
360}
361
362pub fn build_auth_context_from_claims(claims: Claims) -> AuthContext {
368 let user_id = claims.user_id();
370
371 let mut custom_claims = claims.custom;
373 custom_claims.insert("sub".to_string(), serde_json::Value::String(claims.sub));
374
375 match user_id {
376 Some(uuid) => {
377 AuthContext::authenticated(uuid, claims.roles, custom_claims)
379 }
380 None => {
381 AuthContext::authenticated_without_uuid(claims.roles, custom_claims)
384 }
385 }
386}
387
388pub async fn auth_middleware(
390 State(middleware): State<Arc<AuthMiddleware>>,
391 req: Request<Body>,
392 next: Next,
393) -> Response {
394 let token = match extract_token(&req) {
395 Ok(token) => token,
396 Err(e) => {
397 tracing::warn!(error = %e, "Invalid authorization header");
398 return (StatusCode::UNAUTHORIZED, "Invalid authorization header").into_response();
399 }
400 };
401 tracing::trace!(
402 token_present = token.is_some(),
403 "Auth middleware processing request"
404 );
405
406 let auth_context = match extract_auth_context_async(token, &middleware).await {
407 Ok(auth_context) => auth_context,
408 Err(e) => {
409 tracing::warn!(error = %e, "Token validation failed");
410 return (StatusCode::UNAUTHORIZED, "Invalid authentication token").into_response();
411 }
412 };
413 tracing::trace!(
414 authenticated = auth_context.is_authenticated(),
415 "Auth context created"
416 );
417
418 let mut req = req;
419 req.extensions_mut().insert(auth_context);
420
421 next.run(req).await
422}
423
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header, encode};

    /// Builds claims for subject `test-user-id` with role `user`, valid for
    /// one hour or, when `expired` is set, with a duration of minus one hour.
    fn create_test_claims(expired: bool) -> Claims {
        use forge_core::auth::ClaimsBuilder;

        let lifetime_secs = if expired { -3600 } else { 3600 };

        ClaimsBuilder::new()
            .subject("test-user-id")
            .role("user")
            .duration_secs(lifetime_secs)
            .build()
            .unwrap()
    }

    /// Signs `claims` with `secret` using the default JWT header.
    fn create_test_token(claims: &Claims, secret: &str) -> String {
        let signing_key = EncodingKey::from_secret(secret.as_bytes());
        encode(&Header::default(), claims, &signing_key).unwrap()
    }

    #[test]
    fn test_auth_config_default() {
        let config = AuthConfig::default();
        assert_eq!(config.algorithm, JwtAlgorithm::HS256);
        assert!(!config.skip_verification);
    }

    #[test]
    fn test_auth_config_dev_mode() {
        assert!(AuthConfig::dev_mode().skip_verification);
    }

    #[test]
    fn test_auth_middleware_permissive() {
        let middleware = AuthMiddleware::permissive();
        assert!(middleware.config.skip_verification);
    }

    #[tokio::test]
    async fn test_valid_token_with_correct_secret() {
        let secret = "test-secret-key";
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(secret));
        let token = create_test_token(&create_test_claims(false), secret);

        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().sub, "test-user-id");
    }

    #[tokio::test]
    async fn test_valid_token_with_wrong_secret() {
        let middleware = AuthMiddleware::new(AuthConfig::with_secret("correct-secret"));
        let token = create_test_token(&create_test_claims(false), "wrong-secret");

        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_err());
        if !matches!(result, Err(AuthError::InvalidToken(_))) {
            panic!("Expected InvalidToken error");
        }
    }

    #[tokio::test]
    async fn test_expired_token() {
        let secret = "test-secret";
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(secret));
        let token = create_test_token(&create_test_claims(true), secret);

        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_err());
        if !matches!(result, Err(AuthError::TokenExpired)) {
            panic!("Expected TokenExpired error");
        }
    }

    #[tokio::test]
    async fn test_tampered_token() {
        let secret = "test-secret";
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(secret));
        let mut token = create_test_token(&create_test_claims(false), secret);

        // Swap the final signature character for a different one.
        if let Some(last_char) = token.pop() {
            token.push(if last_char == 'a' { 'b' } else { 'a' });
        }

        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_dev_mode_skips_signature() {
        let middleware = AuthMiddleware::new(AuthConfig::dev_mode());
        // Any signing secret is accepted because the signature is not verified.
        let token = create_test_token(&create_test_claims(false), "any-secret");

        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_dev_mode_still_checks_expiration() {
        let middleware = AuthMiddleware::new(AuthConfig::dev_mode());
        let token = create_test_token(&create_test_claims(true), "any-secret");

        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_err());
        if !matches!(result, Err(AuthError::TokenExpired)) {
            panic!("Expected TokenExpired error even in dev mode");
        }
    }

    #[tokio::test]
    async fn test_invalid_token_format() {
        let middleware = AuthMiddleware::new(AuthConfig::with_secret("secret"));

        let result = middleware.validate_token_async("not-a-valid-jwt").await;
        assert!(result.is_err());
        if !matches!(result, Err(AuthError::InvalidToken(_))) {
            panic!("Expected InvalidToken error");
        }
    }

    #[test]
    fn test_algorithm_conversion() {
        let expected_pairs = [
            (JwtAlgorithm::HS256, Algorithm::HS256),
            (JwtAlgorithm::HS384, Algorithm::HS384),
            (JwtAlgorithm::HS512, Algorithm::HS512),
            (JwtAlgorithm::RS256, Algorithm::RS256),
            (JwtAlgorithm::RS384, Algorithm::RS384),
            (JwtAlgorithm::RS512, Algorithm::RS512),
        ];

        for (local, external) in expected_pairs {
            assert_eq!(Algorithm::from(local), external);
        }
    }

    #[test]
    fn test_is_hmac_and_is_rsa() {
        let hmac_config = AuthConfig::with_secret("test");
        assert!(hmac_config.is_hmac());
        assert!(!hmac_config.is_rsa());

        let rsa_config = AuthConfig {
            algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        assert!(!rsa_config.is_hmac());
        assert!(rsa_config.is_rsa());
    }

    #[test]
    fn test_extract_token_rejects_non_bearer_header() {
        let req = Request::builder()
            .header(axum::http::header::AUTHORIZATION, "Basic abc")
            .body(Body::empty())
            .unwrap();

        assert!(matches!(extract_token(&req), Err(AuthError::InvalidHeader)));
    }

    #[tokio::test]
    async fn test_extract_auth_context_async_invalid_token_errors() {
        let middleware = AuthMiddleware::new(AuthConfig::with_secret("secret"));
        let bad_token = Some("bad.token".to_string());

        let result = extract_auth_context_async(bad_token, &middleware).await;
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }
}