1use std::sync::Arc;
2
3use axum::{
4 body::Body,
5 extract::{Request, State},
6 middleware::Next,
7 response::Response,
8};
9use forge_core::auth::Claims;
10use forge_core::config::JwtAlgorithm as CoreJwtAlgorithm;
11use forge_core::function::AuthContext;
12use jsonwebtoken::{Algorithm, DecodingKey, Validation, dangerous, decode};
13use tracing::debug;
14
15use super::jwks::JwksClient;
16
17#[derive(Debug, Clone)]
19pub struct AuthConfig {
20 pub jwt_secret: Option<String>,
22 pub algorithm: JwtAlgorithm,
24 pub jwks_client: Option<Arc<JwksClient>>,
26 pub issuer: Option<String>,
28 pub audience: Option<String>,
30 pub allow_anonymous: bool,
32 pub skip_verification: bool,
34}
35
36impl Default for AuthConfig {
37 fn default() -> Self {
38 Self {
39 jwt_secret: None,
40 algorithm: JwtAlgorithm::HS256,
41 jwks_client: None,
42 issuer: None,
43 audience: None,
44 allow_anonymous: true,
45 skip_verification: false,
46 }
47 }
48}
49
50impl AuthConfig {
51 pub fn from_forge_config(config: &forge_core::config::AuthConfig) -> Self {
53 let algorithm = JwtAlgorithm::from(config.algorithm);
54
55 let jwks_client = config
56 .jwks_url
57 .as_ref()
58 .map(|url| Arc::new(JwksClient::new(url.clone(), config.jwks_cache_ttl_secs)));
59
60 Self {
61 jwt_secret: config.jwt_secret.clone(),
62 algorithm,
63 jwks_client,
64 issuer: config.issuer.clone(),
65 audience: config.audience.clone(),
66 allow_anonymous: config.allow_anonymous,
67 skip_verification: false,
68 }
69 }
70
71 pub fn with_secret(secret: impl Into<String>) -> Self {
73 Self {
74 jwt_secret: Some(secret.into()),
75 ..Default::default()
76 }
77 }
78
79 pub fn dev_mode() -> Self {
82 Self {
83 jwt_secret: None,
84 algorithm: JwtAlgorithm::HS256,
85 jwks_client: None,
86 issuer: None,
87 audience: None,
88 allow_anonymous: true,
89 skip_verification: true,
90 }
91 }
92
93 pub fn is_hmac(&self) -> bool {
95 matches!(
96 self.algorithm,
97 JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
98 )
99 }
100
101 pub fn is_rsa(&self) -> bool {
103 matches!(
104 self.algorithm,
105 JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
106 )
107 }
108}
109
/// JWT signing algorithms accepted by the middleware.
///
/// `HS*` variants are verified with the shared `jwt_secret`; `RS*` variants are
/// verified with public keys obtained from the configured JWKS client.
/// Defaults to `HS256`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum JwtAlgorithm {
    /// HMAC with SHA-256 (the default).
    #[default]
    HS256,
    /// HMAC with SHA-384.
    HS384,
    /// HMAC with SHA-512.
    HS512,
    /// RSA PKCS#1 v1.5 with SHA-256.
    RS256,
    /// RSA PKCS#1 v1.5 with SHA-384.
    RS384,
    /// RSA PKCS#1 v1.5 with SHA-512.
    RS512,
}
121
122impl From<JwtAlgorithm> for Algorithm {
123 fn from(alg: JwtAlgorithm) -> Self {
124 match alg {
125 JwtAlgorithm::HS256 => Algorithm::HS256,
126 JwtAlgorithm::HS384 => Algorithm::HS384,
127 JwtAlgorithm::HS512 => Algorithm::HS512,
128 JwtAlgorithm::RS256 => Algorithm::RS256,
129 JwtAlgorithm::RS384 => Algorithm::RS384,
130 JwtAlgorithm::RS512 => Algorithm::RS512,
131 }
132 }
133}
134
135impl From<CoreJwtAlgorithm> for JwtAlgorithm {
136 fn from(alg: CoreJwtAlgorithm) -> Self {
137 match alg {
138 CoreJwtAlgorithm::HS256 => JwtAlgorithm::HS256,
139 CoreJwtAlgorithm::HS384 => JwtAlgorithm::HS384,
140 CoreJwtAlgorithm::HS512 => JwtAlgorithm::HS512,
141 CoreJwtAlgorithm::RS256 => JwtAlgorithm::RS256,
142 CoreJwtAlgorithm::RS384 => JwtAlgorithm::RS384,
143 CoreJwtAlgorithm::RS512 => JwtAlgorithm::RS512,
144 }
145 }
146}
147
/// JWT validator used by [`auth_middleware`] to build per-request auth contexts.
#[derive(Clone)]
pub struct AuthMiddleware {
    /// Shared authentication settings.
    config: Arc<AuthConfig>,
    /// HMAC decoding key derived once in `AuthMiddleware::new` from
    /// `jwt_secret`; `None` when verification is skipped, the algorithm is not
    /// HMAC, or no non-empty secret is configured.
    hmac_key: Option<DecodingKey>,
}
155
156impl std::fmt::Debug for AuthMiddleware {
157 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158 f.debug_struct("AuthMiddleware")
159 .field("config", &self.config)
160 .field("hmac_key", &self.hmac_key.is_some())
161 .finish()
162 }
163}
164
165impl AuthMiddleware {
166 pub fn new(config: AuthConfig) -> Self {
168 let hmac_key = if config.skip_verification {
170 None
171 } else if config.is_hmac() {
172 config
173 .jwt_secret
174 .as_ref()
175 .filter(|s| !s.is_empty())
176 .map(|secret| DecodingKey::from_secret(secret.as_bytes()))
177 } else {
178 None
179 };
180
181 Self {
182 config: Arc::new(config),
183 hmac_key,
184 }
185 }
186
187 pub fn permissive() -> Self {
190 Self::new(AuthConfig::dev_mode())
191 }
192
193 pub fn config(&self) -> &AuthConfig {
195 &self.config
196 }
197
198 pub fn validate_token(&self, token: &str) -> Result<Claims, AuthError> {
201 if self.config.skip_verification {
202 return self.decode_without_verification(token);
203 }
204
205 if self.config.is_hmac() {
206 self.validate_hmac(token)
207 } else {
208 Err(AuthError::InvalidToken(
210 "RSA validation requires async. Use validate_token_async.".to_string(),
211 ))
212 }
213 }
214
215 pub async fn validate_token_async(&self, token: &str) -> Result<Claims, AuthError> {
217 if self.config.skip_verification {
218 return self.decode_without_verification(token);
219 }
220
221 if self.config.is_hmac() {
222 self.validate_hmac(token)
223 } else {
224 self.validate_rsa(token).await
225 }
226 }
227
228 fn validate_hmac(&self, token: &str) -> Result<Claims, AuthError> {
230 let key = self.hmac_key.as_ref().ok_or_else(|| {
231 AuthError::InvalidToken("JWT secret not configured for HMAC".to_string())
232 })?;
233
234 self.decode_and_validate(token, key)
235 }
236
237 async fn validate_rsa(&self, token: &str) -> Result<Claims, AuthError> {
239 let jwks = self.config.jwks_client.as_ref().ok_or_else(|| {
240 AuthError::InvalidToken("JWKS URL not configured for RSA".to_string())
241 })?;
242
243 let header = jsonwebtoken::decode_header(token)
245 .map_err(|e| AuthError::InvalidToken(format!("Invalid token header: {}", e)))?;
246
247 debug!(kid = ?header.kid, alg = ?header.alg, "Validating RSA token");
248
249 let key = if let Some(kid) = header.kid {
251 jwks.get_key(&kid).await.map_err(|e| {
252 AuthError::InvalidToken(format!("Failed to get key '{}': {}", kid, e))
253 })?
254 } else {
255 jwks.get_any_key()
256 .await
257 .map_err(|e| AuthError::InvalidToken(format!("Failed to get JWKS key: {}", e)))?
258 };
259
260 self.decode_and_validate(token, &key)
261 }
262
263 fn decode_and_validate(&self, token: &str, key: &DecodingKey) -> Result<Claims, AuthError> {
265 let mut validation = Validation::new(self.config.algorithm.into());
266
267 validation.validate_exp = true;
269 validation.validate_nbf = false;
270 validation.leeway = 60; validation.set_required_spec_claims(&["exp", "sub"]);
274
275 if let Some(ref issuer) = self.config.issuer {
277 validation.set_issuer(&[issuer]);
278 }
279
280 if let Some(ref audience) = self.config.audience {
282 validation.set_audience(&[audience]);
283 } else {
284 validation.validate_aud = false;
285 }
286
287 let token_data =
288 decode::<Claims>(token, key, &validation).map_err(|e| self.map_jwt_error(e))?;
289
290 Ok(token_data.claims)
291 }
292
293 fn map_jwt_error(&self, e: jsonwebtoken::errors::Error) -> AuthError {
295 match e.kind() {
296 jsonwebtoken::errors::ErrorKind::ExpiredSignature => AuthError::TokenExpired,
297 jsonwebtoken::errors::ErrorKind::InvalidSignature => {
298 AuthError::InvalidToken("Invalid signature".to_string())
299 }
300 jsonwebtoken::errors::ErrorKind::InvalidToken => {
301 AuthError::InvalidToken("Invalid token format".to_string())
302 }
303 jsonwebtoken::errors::ErrorKind::MissingRequiredClaim(claim) => {
304 AuthError::InvalidToken(format!("Missing required claim: {}", claim))
305 }
306 jsonwebtoken::errors::ErrorKind::InvalidIssuer => {
307 AuthError::InvalidToken("Invalid issuer".to_string())
308 }
309 jsonwebtoken::errors::ErrorKind::InvalidAudience => {
310 AuthError::InvalidToken("Invalid audience".to_string())
311 }
312 _ => AuthError::InvalidToken(e.to_string()),
313 }
314 }
315
316 fn decode_without_verification(&self, token: &str) -> Result<Claims, AuthError> {
318 let token_data =
319 dangerous::insecure_decode::<Claims>(token).map_err(|e| match e.kind() {
320 jsonwebtoken::errors::ErrorKind::InvalidToken => {
321 AuthError::InvalidToken("Invalid token format".to_string())
322 }
323 _ => AuthError::InvalidToken(e.to_string()),
324 })?;
325
326 if token_data.claims.is_expired() {
328 return Err(AuthError::TokenExpired);
329 }
330
331 Ok(token_data.claims)
332 }
333}
334
/// Reasons a request's bearer token could not be turned into claims.
#[derive(Debug, Clone, thiserror::Error)]
pub enum AuthError {
    /// No `Authorization` header was present.
    /// NOTE(review): not constructed anywhere in this file — presumably used by callers.
    #[error("Missing authorization header")]
    MissingHeader,
    /// The `Authorization` header was not in the expected format.
    /// NOTE(review): not constructed anywhere in this file — presumably used by callers.
    #[error("Invalid authorization header format")]
    InvalidHeader,
    /// The token failed decoding, signature, claim, or key-lookup checks;
    /// the payload describes which.
    #[error("Invalid token: {0}")]
    InvalidToken(String),
    /// The token's `exp` claim is in the past.
    #[error("Token expired")]
    TokenExpired,
}
347
348pub fn extract_token(req: &Request<Body>) -> Option<String> {
350 req.headers()
351 .get(axum::http::header::AUTHORIZATION)
352 .and_then(|v| v.to_str().ok())
353 .filter(|header| header.starts_with("Bearer "))
354 .map(|header| header.trim_start_matches("Bearer ").trim().to_string())
355}
356
357pub async fn extract_auth_context_async(
359 token: Option<String>,
360 middleware: &AuthMiddleware,
361) -> AuthContext {
362 match token {
363 Some(token) => match middleware.validate_token_async(&token).await {
364 Ok(claims) => build_auth_context_from_claims(claims),
365 Err(e) => {
366 tracing::warn!(error = %e, "Token validation failed");
367 AuthContext::unauthenticated()
368 }
369 },
370 None => AuthContext::unauthenticated(),
371 }
372}
373
374pub fn build_auth_context_from_claims(claims: Claims) -> AuthContext {
380 let user_id = claims.user_id();
382
383 let mut custom_claims = claims.custom;
385 custom_claims.insert("sub".to_string(), serde_json::Value::String(claims.sub));
386
387 match user_id {
388 Some(uuid) => {
389 AuthContext::authenticated(uuid, claims.roles, custom_claims)
391 }
392 None => {
393 AuthContext::authenticated_without_uuid(claims.roles, custom_claims)
396 }
397 }
398}
399
400pub async fn auth_middleware(
402 State(middleware): State<Arc<AuthMiddleware>>,
403 req: Request<Body>,
404 next: Next,
405) -> Response {
406 let token = extract_token(&req);
407 tracing::trace!(
408 token_present = token.is_some(),
409 "Auth middleware processing request"
410 );
411
412 let auth_context = extract_auth_context_async(token, &middleware).await;
413 tracing::trace!(
414 authenticated = auth_context.is_authenticated(),
415 "Auth context created"
416 );
417
418 let mut req = req;
419 req.extensions_mut().insert(auth_context);
420
421 next.run(req).await
422}
423
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header, encode};

    /// Claims for `test-user-id` with role `user`, valid for the next hour, or
    /// — when `expired` is true — expired one hour ago.
    fn create_test_claims(expired: bool) -> Claims {
        use forge_core::auth::ClaimsBuilder;

        let mut builder = ClaimsBuilder::new().subject("test-user-id").role("user");

        if expired {
            builder = builder.duration_secs(-3600);
        } else {
            builder = builder.duration_secs(3600);
        }

        builder.build().unwrap()
    }

    /// Signs `claims` with HS256 (the default header) using `secret`.
    fn create_test_token(claims: &Claims, secret: &str) -> String {
        encode(
            &Header::default(),
            claims,
            &EncodingKey::from_secret(secret.as_bytes()),
        )
        .unwrap()
    }

    #[test]
    fn test_auth_config_default() {
        let config = AuthConfig::default();
        assert!(config.allow_anonymous);
        assert_eq!(config.algorithm, JwtAlgorithm::HS256);
        assert!(!config.skip_verification);
    }

    #[test]
    fn test_auth_config_dev_mode() {
        let config = AuthConfig::dev_mode();
        assert!(config.skip_verification);
        assert!(config.allow_anonymous);
    }

    #[test]
    fn test_auth_middleware_permissive() {
        let middleware = AuthMiddleware::permissive();
        assert!(middleware.config.skip_verification);
    }

    #[test]
    fn test_valid_token_with_correct_secret() {
        let secret = "test-secret-key";
        let config = AuthConfig::with_secret(secret);
        let middleware = AuthMiddleware::new(config);

        let claims = create_test_claims(false);
        let token = create_test_token(&claims, secret);

        let result = middleware.validate_token(&token);
        assert!(result.is_ok());
        let validated_claims = result.unwrap();
        assert_eq!(validated_claims.sub, "test-user-id");
    }

    #[test]
    fn test_valid_token_with_wrong_secret() {
        let config = AuthConfig::with_secret("correct-secret");
        let middleware = AuthMiddleware::new(config);

        let claims = create_test_claims(false);
        let token = create_test_token(&claims, "wrong-secret");

        let result = middleware.validate_token(&token);
        assert!(result.is_err());
        match result {
            Err(AuthError::InvalidToken(_)) => {}
            _ => panic!("Expected InvalidToken error"),
        }
    }

    #[test]
    fn test_expired_token() {
        let secret = "test-secret";
        let config = AuthConfig::with_secret(secret);
        let middleware = AuthMiddleware::new(config);

        let claims = create_test_claims(true);
        let token = create_test_token(&claims, secret);

        let result = middleware.validate_token(&token);
        assert!(result.is_err());
        match result {
            Err(AuthError::TokenExpired) => {}
            _ => panic!("Expected TokenExpired error"),
        }
    }

    #[test]
    fn test_tampered_token() {
        let secret = "test-secret";
        let config = AuthConfig::with_secret(secret);
        let middleware = AuthMiddleware::new(config);

        let claims = create_test_claims(false);
        let mut token = create_test_token(&claims, secret);

        // Flip the last signature character.
        if let Some(last_char) = token.pop() {
            let replacement = if last_char == 'a' { 'b' } else { 'a' };
            token.push(replacement);
        }

        let result = middleware.validate_token(&token);
        assert!(result.is_err());
    }

    #[test]
    fn test_dev_mode_skips_signature() {
        let config = AuthConfig::dev_mode();
        let middleware = AuthMiddleware::new(config);

        // Signed with a secret the middleware has never seen.
        let claims = create_test_claims(false);
        let token = create_test_token(&claims, "any-secret");

        let result = middleware.validate_token(&token);
        assert!(result.is_ok());
    }

    #[test]
    fn test_dev_mode_still_checks_expiration() {
        let config = AuthConfig::dev_mode();
        let middleware = AuthMiddleware::new(config);

        let claims = create_test_claims(true);
        let token = create_test_token(&claims, "any-secret");

        let result = middleware.validate_token(&token);
        assert!(result.is_err());
        match result {
            Err(AuthError::TokenExpired) => {}
            _ => panic!("Expected TokenExpired error even in dev mode"),
        }
    }

    #[test]
    fn test_dev_mode_rejects_malformed_token() {
        let middleware = AuthMiddleware::permissive();

        let result = middleware.validate_token("not-a-valid-jwt");
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }

    #[test]
    fn test_invalid_token_format() {
        let config = AuthConfig::with_secret("secret");
        let middleware = AuthMiddleware::new(config);

        let result = middleware.validate_token("not-a-valid-jwt");
        assert!(result.is_err());
        match result {
            Err(AuthError::InvalidToken(_)) => {}
            _ => panic!("Expected InvalidToken error"),
        }
    }

    #[test]
    fn test_hmac_without_secret_rejects_tokens() {
        // Default config selects HS256 but carries no secret, so no key exists.
        let middleware = AuthMiddleware::new(AuthConfig::default());
        assert!(middleware.hmac_key.is_none());

        let token = create_test_token(&create_test_claims(false), "some-secret");
        let result = middleware.validate_token(&token);
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }

    #[test]
    fn test_empty_secret_is_treated_as_missing() {
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(""));
        assert!(middleware.hmac_key.is_none());

        let token = create_test_token(&create_test_claims(false), "some-secret");
        let result = middleware.validate_token(&token);
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }

    #[test]
    fn test_algorithm_mismatch_is_rejected() {
        let secret = "test-secret";
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(secret));

        // Correct secret, but signed with HS384 while the middleware expects HS256.
        let token = encode(
            &Header::new(Algorithm::HS384),
            &create_test_claims(false),
            &EncodingKey::from_secret(secret.as_bytes()),
        )
        .unwrap();

        let result = middleware.validate_token(&token);
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }

    #[test]
    fn test_sync_validation_rejects_rsa() {
        let config = AuthConfig {
            algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        let middleware = AuthMiddleware::new(config);

        let token = create_test_token(&create_test_claims(false), "irrelevant");
        let result = middleware.validate_token(&token);
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }

    #[test]
    fn test_algorithm_conversion() {
        assert_eq!(Algorithm::from(JwtAlgorithm::HS256), Algorithm::HS256);
        assert_eq!(Algorithm::from(JwtAlgorithm::HS384), Algorithm::HS384);
        assert_eq!(Algorithm::from(JwtAlgorithm::HS512), Algorithm::HS512);
        assert_eq!(Algorithm::from(JwtAlgorithm::RS256), Algorithm::RS256);
        assert_eq!(Algorithm::from(JwtAlgorithm::RS384), Algorithm::RS384);
        assert_eq!(Algorithm::from(JwtAlgorithm::RS512), Algorithm::RS512);
    }

    #[test]
    fn test_is_hmac_and_is_rsa() {
        let hmac_config = AuthConfig::with_secret("test");
        assert!(hmac_config.is_hmac());
        assert!(!hmac_config.is_rsa());

        let rsa_config = AuthConfig {
            algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        assert!(!rsa_config.is_hmac());
        assert!(rsa_config.is_rsa());
    }
}