1use std::sync::Arc;
2
3use axum::{
4 body::Body,
5 extract::{Request, State},
6 middleware::Next,
7 response::Response,
8};
9use forge_core::auth::Claims;
10use forge_core::config::JwtAlgorithm as CoreJwtAlgorithm;
11use forge_core::function::AuthContext;
12use jsonwebtoken::{Algorithm, DecodingKey, Validation, dangerous, decode};
13use tracing::debug;
14
15use super::jwks::JwksClient;
16
17#[derive(Debug, Clone)]
19pub struct AuthConfig {
20 pub jwt_secret: Option<String>,
22 pub algorithm: JwtAlgorithm,
24 pub jwks_client: Option<Arc<JwksClient>>,
26 pub issuer: Option<String>,
28 pub audience: Option<String>,
30 pub skip_verification: bool,
32}
33
34impl Default for AuthConfig {
35 fn default() -> Self {
36 Self {
37 jwt_secret: None,
38 algorithm: JwtAlgorithm::HS256,
39 jwks_client: None,
40 issuer: None,
41 audience: None,
42 skip_verification: false,
43 }
44 }
45}
46
47impl AuthConfig {
48 pub fn from_forge_config(config: &forge_core::config::AuthConfig) -> Self {
50 let algorithm = JwtAlgorithm::from(config.jwt_algorithm);
51
52 let jwks_client = config
53 .jwks_url
54 .as_ref()
55 .map(|url| Arc::new(JwksClient::new(url.clone(), config.jwks_cache_ttl_secs)));
56
57 Self {
58 jwt_secret: config.jwt_secret.clone(),
59 algorithm,
60 jwks_client,
61 issuer: config.jwt_issuer.clone(),
62 audience: config.jwt_audience.clone(),
63 skip_verification: false,
64 }
65 }
66
67 pub fn with_secret(secret: impl Into<String>) -> Self {
69 Self {
70 jwt_secret: Some(secret.into()),
71 ..Default::default()
72 }
73 }
74
75 pub fn dev_mode() -> Self {
78 Self {
79 jwt_secret: None,
80 algorithm: JwtAlgorithm::HS256,
81 jwks_client: None,
82 issuer: None,
83 audience: None,
84 skip_verification: true,
85 }
86 }
87
88 pub fn is_hmac(&self) -> bool {
90 matches!(
91 self.algorithm,
92 JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
93 )
94 }
95
96 pub fn is_rsa(&self) -> bool {
98 matches!(
99 self.algorithm,
100 JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
101 )
102 }
103}
104
/// JWT signing algorithms accepted by the auth middleware.
///
/// The HMAC variants (`HS*`) are verified with the shared `jwt_secret`. The RSA
/// variants (`RS*`) are verified with public keys obtained from the JWKS client.
/// Defaults to `HS256`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum JwtAlgorithm {
    /// HMAC using SHA-256 (the default).
    #[default]
    HS256,
    /// HMAC using SHA-384.
    HS384,
    /// HMAC using SHA-512.
    HS512,
    /// RSA signature using SHA-256.
    RS256,
    /// RSA signature using SHA-384.
    RS384,
    /// RSA signature using SHA-512.
    RS512,
}
116
117impl From<JwtAlgorithm> for Algorithm {
118 fn from(alg: JwtAlgorithm) -> Self {
119 match alg {
120 JwtAlgorithm::HS256 => Algorithm::HS256,
121 JwtAlgorithm::HS384 => Algorithm::HS384,
122 JwtAlgorithm::HS512 => Algorithm::HS512,
123 JwtAlgorithm::RS256 => Algorithm::RS256,
124 JwtAlgorithm::RS384 => Algorithm::RS384,
125 JwtAlgorithm::RS512 => Algorithm::RS512,
126 }
127 }
128}
129
130impl From<CoreJwtAlgorithm> for JwtAlgorithm {
131 fn from(alg: CoreJwtAlgorithm) -> Self {
132 match alg {
133 CoreJwtAlgorithm::HS256 => JwtAlgorithm::HS256,
134 CoreJwtAlgorithm::HS384 => JwtAlgorithm::HS384,
135 CoreJwtAlgorithm::HS512 => JwtAlgorithm::HS512,
136 CoreJwtAlgorithm::RS256 => JwtAlgorithm::RS256,
137 CoreJwtAlgorithm::RS384 => JwtAlgorithm::RS384,
138 CoreJwtAlgorithm::RS512 => JwtAlgorithm::RS512,
139 }
140 }
141}
142
/// Token validator used by [`auth_middleware`] (which receives it as `State<Arc<AuthMiddleware>>`).
#[derive(Clone)]
pub struct AuthMiddleware {
    /// Configuration, behind an `Arc` so clones share it instead of copying it.
    config: Arc<AuthConfig>,
    /// Pre-built HMAC decoding key. It is `None` in dev mode, for RSA algorithms,
    /// or when no non-empty secret is configured.
    hmac_key: Option<DecodingKey>,
}
150
151impl std::fmt::Debug for AuthMiddleware {
152 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 f.debug_struct("AuthMiddleware")
154 .field("config", &self.config)
155 .field("hmac_key", &self.hmac_key.is_some())
156 .finish()
157 }
158}
159
160impl AuthMiddleware {
161 pub fn new(config: AuthConfig) -> Self {
163 let hmac_key = if config.skip_verification {
165 None
166 } else if config.is_hmac() {
167 config
168 .jwt_secret
169 .as_ref()
170 .filter(|s| !s.is_empty())
171 .map(|secret| DecodingKey::from_secret(secret.as_bytes()))
172 } else {
173 None
174 };
175
176 Self {
177 config: Arc::new(config),
178 hmac_key,
179 }
180 }
181
182 pub fn permissive() -> Self {
185 Self::new(AuthConfig::dev_mode())
186 }
187
188 pub fn config(&self) -> &AuthConfig {
190 &self.config
191 }
192
193 pub async fn validate_token_async(&self, token: &str) -> Result<Claims, AuthError> {
195 if self.config.skip_verification {
196 return self.decode_without_verification(token);
197 }
198
199 if self.config.is_hmac() {
200 self.validate_hmac(token)
201 } else {
202 self.validate_rsa(token).await
203 }
204 }
205
206 fn validate_hmac(&self, token: &str) -> Result<Claims, AuthError> {
208 let key = self.hmac_key.as_ref().ok_or_else(|| {
209 AuthError::InvalidToken("JWT secret not configured for HMAC".to_string())
210 })?;
211
212 self.decode_and_validate(token, key)
213 }
214
215 async fn validate_rsa(&self, token: &str) -> Result<Claims, AuthError> {
217 let jwks = self.config.jwks_client.as_ref().ok_or_else(|| {
218 AuthError::InvalidToken("JWKS URL not configured for RSA".to_string())
219 })?;
220
221 let header = jsonwebtoken::decode_header(token)
223 .map_err(|e| AuthError::InvalidToken(format!("Invalid token header: {}", e)))?;
224
225 debug!(kid = ?header.kid, alg = ?header.alg, "Validating RSA token");
226
227 let key = if let Some(kid) = header.kid {
229 jwks.get_key(&kid).await.map_err(|e| {
230 AuthError::InvalidToken(format!("Failed to get key '{}': {}", kid, e))
231 })?
232 } else {
233 jwks.get_any_key()
234 .await
235 .map_err(|e| AuthError::InvalidToken(format!("Failed to get JWKS key: {}", e)))?
236 };
237
238 self.decode_and_validate(token, &key)
239 }
240
241 fn decode_and_validate(&self, token: &str, key: &DecodingKey) -> Result<Claims, AuthError> {
243 let mut validation = Validation::new(self.config.algorithm.into());
244
245 validation.validate_exp = true;
247 validation.validate_nbf = false;
248 validation.leeway = 60; validation.set_required_spec_claims(&["exp", "sub"]);
252
253 if let Some(ref issuer) = self.config.issuer {
255 validation.set_issuer(&[issuer]);
256 }
257
258 if let Some(ref audience) = self.config.audience {
260 validation.set_audience(&[audience]);
261 } else {
262 validation.validate_aud = false;
263 }
264
265 let token_data =
266 decode::<Claims>(token, key, &validation).map_err(|e| self.map_jwt_error(e))?;
267
268 Ok(token_data.claims)
269 }
270
271 fn map_jwt_error(&self, e: jsonwebtoken::errors::Error) -> AuthError {
273 match e.kind() {
274 jsonwebtoken::errors::ErrorKind::ExpiredSignature => AuthError::TokenExpired,
275 jsonwebtoken::errors::ErrorKind::InvalidSignature => {
276 AuthError::InvalidToken("Invalid signature".to_string())
277 }
278 jsonwebtoken::errors::ErrorKind::InvalidToken => {
279 AuthError::InvalidToken("Invalid token format".to_string())
280 }
281 jsonwebtoken::errors::ErrorKind::MissingRequiredClaim(claim) => {
282 AuthError::InvalidToken(format!("Missing required claim: {}", claim))
283 }
284 jsonwebtoken::errors::ErrorKind::InvalidIssuer => {
285 AuthError::InvalidToken("Invalid issuer".to_string())
286 }
287 jsonwebtoken::errors::ErrorKind::InvalidAudience => {
288 AuthError::InvalidToken("Invalid audience".to_string())
289 }
290 _ => AuthError::InvalidToken(e.to_string()),
291 }
292 }
293
294 fn decode_without_verification(&self, token: &str) -> Result<Claims, AuthError> {
296 let token_data =
297 dangerous::insecure_decode::<Claims>(token).map_err(|e| match e.kind() {
298 jsonwebtoken::errors::ErrorKind::InvalidToken => {
299 AuthError::InvalidToken("Invalid token format".to_string())
300 }
301 _ => AuthError::InvalidToken(e.to_string()),
302 })?;
303
304 if token_data.claims.is_expired() {
306 return Err(AuthError::TokenExpired);
307 }
308
309 Ok(token_data.claims)
310 }
311}
312
/// Errors produced while extracting or validating a bearer token.
#[derive(Debug, Clone, thiserror::Error)]
pub enum AuthError {
    /// No `Authorization` header was supplied.
    /// NOTE(review): not constructed in this module; presumably used by callers. Confirm.
    #[error("Missing authorization header")]
    MissingHeader,
    /// The `Authorization` header could not be parsed.
    /// NOTE(review): not constructed in this module; presumably used by callers. Confirm.
    #[error("Invalid authorization header format")]
    InvalidHeader,
    /// The token failed decoding or validation; the payload describes why.
    #[error("Invalid token: {0}")]
    InvalidToken(String),
    /// The token's `exp` claim is in the past.
    #[error("Token expired")]
    TokenExpired,
}
325
326pub fn extract_token(req: &Request<Body>) -> Option<String> {
328 req.headers()
329 .get(axum::http::header::AUTHORIZATION)
330 .and_then(|v| v.to_str().ok())
331 .filter(|header| header.starts_with("Bearer "))
332 .map(|header| header.trim_start_matches("Bearer ").trim().to_string())
333}
334
335pub async fn extract_auth_context_async(
337 token: Option<String>,
338 middleware: &AuthMiddleware,
339) -> AuthContext {
340 match token {
341 Some(token) => match middleware.validate_token_async(&token).await {
342 Ok(claims) => build_auth_context_from_claims(claims),
343 Err(e) => {
344 tracing::warn!(error = %e, "Token validation failed");
345 AuthContext::unauthenticated()
346 }
347 },
348 None => AuthContext::unauthenticated(),
349 }
350}
351
352pub fn build_auth_context_from_claims(claims: Claims) -> AuthContext {
358 let user_id = claims.user_id();
360
361 let mut custom_claims = claims.custom;
363 custom_claims.insert("sub".to_string(), serde_json::Value::String(claims.sub));
364
365 match user_id {
366 Some(uuid) => {
367 AuthContext::authenticated(uuid, claims.roles, custom_claims)
369 }
370 None => {
371 AuthContext::authenticated_without_uuid(claims.roles, custom_claims)
374 }
375 }
376}
377
378pub async fn auth_middleware(
380 State(middleware): State<Arc<AuthMiddleware>>,
381 req: Request<Body>,
382 next: Next,
383) -> Response {
384 let token = extract_token(&req);
385 tracing::trace!(
386 token_present = token.is_some(),
387 "Auth middleware processing request"
388 );
389
390 let auth_context = extract_auth_context_async(token, &middleware).await;
391 tracing::trace!(
392 authenticated = auth_context.is_authenticated(),
393 "Auth context created"
394 );
395
396 let mut req = req;
397 req.extensions_mut().insert(auth_context);
398
399 next.run(req).await
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header, encode};

    /// Claims for subject `test-user-id` with role `user`. They are valid for one
    /// hour, or already expired by one hour when `expired` is true.
    fn create_test_claims(expired: bool) -> Claims {
        use forge_core::auth::ClaimsBuilder;

        let duration = if expired { -3600 } else { 3600 };

        ClaimsBuilder::new()
            .subject("test-user-id")
            .role("user")
            .duration_secs(duration)
            .build()
            .unwrap()
    }

    /// Signs `claims` with the default header algorithm (HS256) using `secret`.
    fn create_test_token(claims: &Claims, secret: &str) -> String {
        encode(
            &Header::default(),
            claims,
            &EncodingKey::from_secret(secret.as_bytes()),
        )
        .unwrap()
    }

    /// Builds a bodiless request, optionally carrying an `Authorization` header.
    fn request_with_authorization(value: Option<&str>) -> Request<Body> {
        let mut builder = axum::http::Request::builder();
        if let Some(value) = value {
            builder = builder.header(axum::http::header::AUTHORIZATION, value);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[test]
    fn test_auth_config_default() {
        let config = AuthConfig::default();
        assert_eq!(config.algorithm, JwtAlgorithm::HS256);
        assert!(!config.skip_verification);
    }

    #[test]
    fn test_auth_config_dev_mode() {
        let config = AuthConfig::dev_mode();
        assert!(config.skip_verification);
    }

    #[test]
    fn test_auth_middleware_permissive() {
        let middleware = AuthMiddleware::permissive();
        assert!(middleware.config.skip_verification);
    }

    #[test]
    fn test_extract_token_bearer() {
        let req = request_with_authorization(Some("Bearer abc.def.ghi"));
        assert_eq!(extract_token(&req).as_deref(), Some("abc.def.ghi"));
    }

    #[test]
    fn test_extract_token_missing_or_other_scheme() {
        assert_eq!(extract_token(&request_with_authorization(None)), None);
        assert_eq!(
            extract_token(&request_with_authorization(Some("Basic dXNlcjpwYXNz"))),
            None
        );
    }

    #[tokio::test]
    async fn test_valid_token_with_correct_secret() {
        let secret = "test-secret-key";
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(secret));

        let claims = create_test_claims(false);
        let token = create_test_token(&claims, secret);

        let validated_claims = middleware
            .validate_token_async(&token)
            .await
            .expect("token signed with the configured secret should validate");
        assert_eq!(validated_claims.sub, "test-user-id");
    }

    #[tokio::test]
    async fn test_valid_token_with_wrong_secret() {
        let middleware = AuthMiddleware::new(AuthConfig::with_secret("correct-secret"));

        let claims = create_test_claims(false);
        let token = create_test_token(&claims, "wrong-secret");

        let result = middleware.validate_token_async(&token).await;
        assert!(
            matches!(result, Err(AuthError::InvalidToken(_))),
            "Expected InvalidToken error"
        );
    }

    #[tokio::test]
    async fn test_expired_token() {
        let secret = "test-secret";
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(secret));

        let claims = create_test_claims(true);
        let token = create_test_token(&claims, secret);

        let result = middleware.validate_token_async(&token).await;
        assert!(
            matches!(result, Err(AuthError::TokenExpired)),
            "Expected TokenExpired error"
        );
    }

    #[tokio::test]
    async fn test_tampered_token() {
        let secret = "test-secret";
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(secret));

        let claims = create_test_claims(false);
        let mut token = create_test_token(&claims, secret);

        // Flip the last signature character so the signature no longer matches.
        if let Some(last_char) = token.pop() {
            let replacement = if last_char == 'a' { 'b' } else { 'a' };
            token.push(replacement);
        }

        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_dev_mode_skips_signature() {
        let middleware = AuthMiddleware::new(AuthConfig::dev_mode());

        // Signed with an arbitrary secret the middleware does not know.
        let claims = create_test_claims(false);
        let token = create_test_token(&claims, "any-secret");

        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_dev_mode_still_checks_expiration() {
        let middleware = AuthMiddleware::new(AuthConfig::dev_mode());

        let claims = create_test_claims(true);
        let token = create_test_token(&claims, "any-secret");

        let result = middleware.validate_token_async(&token).await;
        assert!(
            matches!(result, Err(AuthError::TokenExpired)),
            "Expected TokenExpired error even in dev mode"
        );
    }

    #[tokio::test]
    async fn test_invalid_token_format() {
        let middleware = AuthMiddleware::new(AuthConfig::with_secret("secret"));

        let result = middleware.validate_token_async("not-a-valid-jwt").await;
        assert!(
            matches!(result, Err(AuthError::InvalidToken(_))),
            "Expected InvalidToken error"
        );
    }

    #[test]
    fn test_algorithm_conversion() {
        assert_eq!(Algorithm::from(JwtAlgorithm::HS256), Algorithm::HS256);
        assert_eq!(Algorithm::from(JwtAlgorithm::HS384), Algorithm::HS384);
        assert_eq!(Algorithm::from(JwtAlgorithm::HS512), Algorithm::HS512);
        assert_eq!(Algorithm::from(JwtAlgorithm::RS256), Algorithm::RS256);
        assert_eq!(Algorithm::from(JwtAlgorithm::RS384), Algorithm::RS384);
        assert_eq!(Algorithm::from(JwtAlgorithm::RS512), Algorithm::RS512);
    }

    #[test]
    fn test_is_hmac_and_is_rsa() {
        let hmac_config = AuthConfig::with_secret("test");
        assert!(hmac_config.is_hmac());
        assert!(!hmac_config.is_rsa());

        let rsa_config = AuthConfig {
            algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        assert!(!rsa_config.is_hmac());
        assert!(rsa_config.is_rsa());
    }
}