1use std::marker::PhantomData;
2
3use async_trait::async_trait;
4use axum::extract::FromRef;
5use axum::http::{header::HeaderName, StatusCode};
6use axum::response::Response;
7use axum::RequestPartsExt;
8use axum::{http::request::Parts, response::IntoResponse};
9use axum_extra::headers::authorization::Bearer;
10use axum_extra::headers::{Authorization, Cookie};
11use axum_extra::TypedHeader;
12use jsonwebtoken::errors::ErrorKind;
13use serde::de::DeserializeOwned;
14
15use crate::Decoder;
16
17#[derive(Debug)]
43pub struct Claims<T, E = BearerTokenExtractor> {
44 pub claims: T,
46 _extractor: PhantomData<E>,
47}
48
/// Strategy for locating the raw JWT in an incoming request.
///
/// Used as the `E` type parameter of [`Claims`]. Only the associated function
/// is called (as `E::extract_token`), so it takes no `self`.
#[async_trait]
pub trait TokenExtractor {
    /// Pulls the raw token string out of the request `parts`.
    ///
    /// # Errors
    ///
    /// The implementations in this module return [`AuthError::MissingToken`]
    /// when no usable token is present.
    async fn extract_token(parts: &mut Parts) -> Result<String, AuthError>;
}
60
/// Token extractor that reads the `Authorization: Bearer <token>` header.
///
/// This is the default extractor of [`Claims`].
#[derive(Debug, Clone, Copy, Default)]
pub struct BearerTokenExtractor;
65
66#[async_trait]
67impl TokenExtractor for BearerTokenExtractor {
68 async fn extract_token(parts: &mut Parts) -> Result<String, AuthError> {
69 let auth: TypedHeader<Authorization<Bearer>> =
70 parts.extract().await.map_err(|_| AuthError::MissingToken)?;
71
72 Ok(auth.token().to_string())
73 }
74}
75
/// Type-level configuration naming where a token lives.
///
/// This is the header name read by `HeaderTokenExtractor<C>` or the cookie
/// name read by `CookieTokenExtractor<C>`. The `define_header_extractor!`
/// and `define_cookie_extractor!` macros generate implementations of this
/// trait.
pub trait ExtractorConfig {
    /// The header or cookie name to look up.
    fn value() -> &'static str;
}
84
/// Declares a public unit struct `$name` implementing `ExtractorConfig`.
///
/// Its `value()` returns `$header`, which must evaluate to a `&'static str`.
///
/// The struct is meant to parameterize `HeaderTokenExtractor`:
///
/// ```ignore
/// define_header_extractor!(XAuthToken, "x-auth-token");
/// type XAuthTokenExtractor = HeaderTokenExtractor<XAuthToken>;
/// ```
///
/// Prefer a lowercase header name. The expansion names the trait as
/// `$crate::ExtractorConfig`, so the trait must be reachable at the crate
/// root.
#[macro_export]
macro_rules! define_header_extractor {
    ($name:ident, $header:expr) => {
        // Zero-sized marker; the header name lives in the trait impl.
        pub struct $name;
        impl $crate::ExtractorConfig for $name {
            fn value() -> &'static str {
                $header
            }
        }
    };
}
109
/// Declares a public unit struct `$name` implementing `ExtractorConfig`.
///
/// Its `value()` returns the cookie name `$cookie`, which must evaluate to a
/// `&'static str`.
///
/// The struct is meant to parameterize `CookieTokenExtractor`:
///
/// ```ignore
/// define_cookie_extractor!(AuthTokenCookie, "auth_token");
/// type AuthCookieExtractor = CookieTokenExtractor<AuthTokenCookie>;
/// ```
///
/// The expansion names the trait as `$crate::ExtractorConfig`, so the trait
/// must be reachable at the crate root.
#[macro_export]
macro_rules! define_cookie_extractor {
    ($name:ident, $cookie:expr) => {
        // Zero-sized marker; the cookie name lives in the trait impl.
        pub struct $name;
        impl $crate::ExtractorConfig for $name {
            fn value() -> &'static str {
                $cookie
            }
        }
    };
}
134
/// Token extractor that reads the raw token from the request header named by
/// `C::value()`, where `C` implements [`ExtractorConfig`].
///
/// The `ExtractorConfig` bound is enforced on the `TokenExtractor` impl
/// rather than on the struct definition, so naming the type imposes no extra
/// requirements.
pub struct HeaderTokenExtractor<C>(PhantomData<C>);
149
150#[async_trait]
151impl<C: ExtractorConfig> TokenExtractor for HeaderTokenExtractor<C> {
152 async fn extract_token(parts: &mut Parts) -> Result<String, AuthError> {
153 let header_name = HeaderName::from_static(C::value());
154
155 parts
156 .headers
157 .get(&header_name)
158 .and_then(|h| h.to_str().ok())
159 .map(|s| s.to_string())
160 .ok_or(AuthError::MissingToken)
161 }
162}
163
/// Token extractor that reads the raw token from the cookie named by
/// `C::value()`, where `C` implements [`ExtractorConfig`].
///
/// The `ExtractorConfig` bound is enforced on the `TokenExtractor` impl
/// rather than on the struct definition, so naming the type imposes no extra
/// requirements.
pub struct CookieTokenExtractor<C>(PhantomData<C>);
178
179#[async_trait]
180impl<C: ExtractorConfig> TokenExtractor for CookieTokenExtractor<C> {
181 async fn extract_token(parts: &mut Parts) -> Result<String, AuthError> {
182 let cookies: TypedHeader<Cookie> =
183 parts.extract().await.map_err(|_| AuthError::MissingToken)?;
184
185 cookies
186 .get(C::value())
187 .map(|s| s.to_string())
188 .ok_or(AuthError::MissingToken)
189 }
190}
191
192impl<S, T, E> axum::extract::FromRequestParts<S> for Claims<T, E>
193where
194 JwtDecoderState<T>: FromRef<S>,
195 S: Send + Sync,
196 T: DeserializeOwned,
197 E: TokenExtractor,
198{
199 type Rejection = AuthError;
200
201 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
202 let token = E::extract_token(parts).await?;
203
204 let state = JwtDecoderState::from_ref(state);
205 let token_data = state
206 .decoder
207 .clone()
208 .decode(&token)
209 .await
210 .map_err(map_jwt_error)?;
211
212 Ok(Claims {
213 claims: token_data.claims,
214 _extractor: PhantomData,
215 })
216 }
217}
218
219fn map_jwt_error(err: crate::Error) -> AuthError {
221 match err {
222 crate::Error::Jwt(e) => match e.kind() {
223 ErrorKind::ExpiredSignature => AuthError::ExpiredSignature,
224 ErrorKind::InvalidSignature => AuthError::InvalidSignature,
225 ErrorKind::InvalidAudience => AuthError::InvalidAudience,
226 ErrorKind::InvalidAlgorithm => AuthError::InvalidAlgorithm,
227 ErrorKind::InvalidToken => AuthError::InvalidToken,
228 ErrorKind::InvalidIssuer => AuthError::InvalidIssuer,
229 ErrorKind::InvalidSubject => AuthError::InvalidSubject,
230 ErrorKind::ImmatureSignature => AuthError::ImmatureSignature,
231 ErrorKind::MissingAlgorithm => AuthError::MissingAlgorithm,
232 ErrorKind::MissingRequiredClaim(claim) => {
233 AuthError::MissingRequiredClaim(claim.to_string())
234 }
235 _ => AuthError::InternalError,
236 },
237 _ => AuthError::InternalError,
238 }
239}
240
241#[derive(Debug, PartialEq, thiserror::Error)]
245pub enum AuthError {
246 #[error("Invalid token")]
248 InvalidToken,
249
250 #[error("Invalid signature")]
252 InvalidSignature,
253
254 #[error("Missing required claim: {0}")]
256 MissingRequiredClaim(String),
257
258 #[error("Expired signature")]
260 ExpiredSignature,
261
262 #[error("Invalid issuer")]
264 InvalidIssuer,
265
266 #[error("Invalid audience")]
268 InvalidAudience,
269
270 #[error("Invalid subject")]
272 InvalidSubject,
273
274 #[error("Immature signature")]
276 ImmatureSignature,
277
278 #[error("Invalid algorithm")]
280 InvalidAlgorithm,
281
282 #[error("Missing algorithm")]
284 MissingAlgorithm,
285
286 #[error("Missing token")]
288 MissingToken,
289
290 #[error("Internal error")]
292 InternalError,
293}
294
295impl IntoResponse for AuthError {
296 fn into_response(self) -> Response {
297 let (status, msg) = match self {
298 AuthError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"),
299 AuthError::InvalidSignature => (StatusCode::UNAUTHORIZED, "Invalid signature"),
300 AuthError::MissingRequiredClaim(_) => {
301 (StatusCode::UNAUTHORIZED, "Missing required claim")
302 }
303 AuthError::ExpiredSignature => (StatusCode::UNAUTHORIZED, "Expired signature"),
304 AuthError::InvalidIssuer => (StatusCode::UNAUTHORIZED, "Invalid issuer"),
305 AuthError::InvalidAudience => (StatusCode::UNAUTHORIZED, "Invalid audience"),
306 AuthError::InvalidSubject => (StatusCode::UNAUTHORIZED, "Invalid subject"),
307 AuthError::ImmatureSignature => (StatusCode::UNAUTHORIZED, "Immature signature"),
308 AuthError::InvalidAlgorithm => (StatusCode::UNAUTHORIZED, "Invalid algorithm"),
309 AuthError::MissingAlgorithm => (StatusCode::UNAUTHORIZED, "Missing algorithm"),
310 AuthError::MissingToken => (StatusCode::UNAUTHORIZED, "Missing token"),
311 AuthError::InternalError => (StatusCode::INTERNAL_SERVER_ERROR, "Internal error"),
312 };
313
314 (status, msg).into_response()
315 }
316}
317
318#[derive(Clone)]
322pub struct JwtDecoderState<T> {
323 pub decoder: Decoder<T>,
325}
326
327impl<T> FromRef<JwtDecoderState<T>> for Decoder<T> {
328 fn from_ref(state: &JwtDecoderState<T>) -> Self {
329 state.decoder.clone()
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::extract::Request;
    use jsonwebtoken::errors::Error as JwtError;

    /// Builds request `Parts` carrying `headers`, added in the given order.
    fn parts_with(headers: &[(&str, &str)]) -> Parts {
        let builder = headers
            .iter()
            .fold(Request::builder(), |builder, &(name, value)| {
                builder.header(name, value)
            });
        builder.body(Body::empty()).unwrap().into_parts().0
    }

    /// Runs token extractor `X` against a request carrying `headers`.
    async fn run_extractor<X: TokenExtractor>(
        headers: &[(&str, &str)],
    ) -> Result<String, AuthError> {
        let mut parts = parts_with(headers);
        X::extract_token(&mut parts).await
    }

    /// Feeds a JWT error of the given kind through `map_jwt_error`.
    fn map_kind(kind: ErrorKind) -> AuthError {
        map_jwt_error(crate::Error::Jwt(JwtError::from(kind)))
    }

    /// HTTP status produced by rendering `err` as a response.
    fn status_of(err: AuthError) -> StatusCode {
        err.into_response().status()
    }

    // ---- ExtractorConfig macros ---------------------------------------------

    #[test]
    fn test_header_extractor_macro() {
        define_header_extractor!(TestHeader, "x-test-header");
        assert_eq!(TestHeader::value(), "x-test-header");
    }

    #[test]
    fn test_cookie_extractor_macro() {
        define_cookie_extractor!(TestCookie, "test_cookie");
        assert_eq!(TestCookie::value(), "test_cookie");
    }

    // ---- map_jwt_error -------------------------------------------------------

    #[test]
    fn test_map_jwt_error_expired_signature() {
        assert_eq!(map_kind(ErrorKind::ExpiredSignature), AuthError::ExpiredSignature);
    }

    #[test]
    fn test_map_jwt_error_invalid_signature() {
        assert_eq!(map_kind(ErrorKind::InvalidSignature), AuthError::InvalidSignature);
    }

    #[test]
    fn test_map_jwt_error_invalid_audience() {
        assert_eq!(map_kind(ErrorKind::InvalidAudience), AuthError::InvalidAudience);
    }

    #[test]
    fn test_map_jwt_error_invalid_algorithm() {
        assert_eq!(map_kind(ErrorKind::InvalidAlgorithm), AuthError::InvalidAlgorithm);
    }

    #[test]
    fn test_map_jwt_error_invalid_token() {
        assert_eq!(map_kind(ErrorKind::InvalidToken), AuthError::InvalidToken);
    }

    #[test]
    fn test_map_jwt_error_invalid_issuer() {
        assert_eq!(map_kind(ErrorKind::InvalidIssuer), AuthError::InvalidIssuer);
    }

    #[test]
    fn test_map_jwt_error_invalid_subject() {
        assert_eq!(map_kind(ErrorKind::InvalidSubject), AuthError::InvalidSubject);
    }

    #[test]
    fn test_map_jwt_error_immature_signature() {
        assert_eq!(map_kind(ErrorKind::ImmatureSignature), AuthError::ImmatureSignature);
    }

    #[test]
    fn test_map_jwt_error_missing_algorithm() {
        assert_eq!(map_kind(ErrorKind::MissingAlgorithm), AuthError::MissingAlgorithm);
    }

    #[test]
    fn test_map_jwt_error_missing_required_claim() {
        let claim = "sub".to_string();
        assert_eq!(
            map_kind(ErrorKind::MissingRequiredClaim(claim.clone())),
            AuthError::MissingRequiredClaim(claim)
        );
    }

    #[test]
    fn test_map_jwt_error_non_jwt_error() {
        let err = crate::Error::KeyNotFound(Some("test_kid".to_string()));
        assert_eq!(map_jwt_error(err), AuthError::InternalError);
    }

    // ---- BearerTokenExtractor ------------------------------------------------

    #[tokio::test]
    async fn test_bearer_token_extractor_valid() {
        let token =
            run_extractor::<BearerTokenExtractor>(&[("Authorization", "Bearer test_token")]).await;
        assert_eq!(token, Ok("test_token".to_string()));
    }

    #[tokio::test]
    async fn test_bearer_token_extractor_valid_long_token() {
        let long_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";
        let header = format!("Bearer {long_token}");
        let token =
            run_extractor::<BearerTokenExtractor>(&[("Authorization", header.as_str())]).await;
        assert_eq!(token, Ok(long_token.to_string()));
    }

    #[tokio::test]
    async fn test_bearer_token_extractor_invalid_scheme() {
        let token =
            run_extractor::<BearerTokenExtractor>(&[("Authorization", "Basic dXNlcjpwYXNz")])
                .await;
        assert_eq!(token, Err(AuthError::MissingToken));
    }

    #[tokio::test]
    async fn test_bearer_token_extractor_malformed_header() {
        let token =
            run_extractor::<BearerTokenExtractor>(&[("Authorization", "BearerMissingSpace")])
                .await;
        assert_eq!(token, Err(AuthError::MissingToken));
    }

    #[tokio::test]
    async fn test_bearer_token_extractor_missing_header() {
        let token = run_extractor::<BearerTokenExtractor>(&[]).await;
        assert_eq!(token, Err(AuthError::MissingToken));
    }

    #[tokio::test]
    async fn test_bearer_token_extractor_empty_token() {
        // The scheme is present but the token after it is empty.
        let token = run_extractor::<BearerTokenExtractor>(&[("Authorization", "Bearer ")]).await;
        assert_eq!(token, Ok(String::new()));
    }

    #[tokio::test]
    async fn test_bearer_token_extractor_case_sensitivity() {
        // The scheme name is matched case-insensitively.
        let token =
            run_extractor::<BearerTokenExtractor>(&[("Authorization", "bearer test_token")]).await;
        assert_eq!(token, Ok("test_token".to_string()));
    }

    // ---- HeaderTokenExtractor ------------------------------------------------

    #[tokio::test]
    async fn test_header_token_extractor_valid() {
        define_header_extractor!(XAuthToken, "x-auth-token");
        let token = run_extractor::<HeaderTokenExtractor<XAuthToken>>(&[(
            "x-auth-token",
            "test_token_123",
        )])
        .await;
        assert_eq!(token, Ok("test_token_123".to_string()));
    }

    #[tokio::test]
    async fn test_header_token_extractor_missing_header() {
        define_header_extractor!(XAuthToken2, "x-auth-token");
        let token = run_extractor::<HeaderTokenExtractor<XAuthToken2>>(&[]).await;
        assert_eq!(token, Err(AuthError::MissingToken));
    }

    #[tokio::test]
    async fn test_header_token_extractor_empty_value() {
        define_header_extractor!(XAuthToken3, "x-auth-token");
        let token =
            run_extractor::<HeaderTokenExtractor<XAuthToken3>>(&[("x-auth-token", "")]).await;
        assert_eq!(token, Ok(String::new()));
    }

    #[tokio::test]
    async fn test_header_token_extractor_special_characters() {
        define_header_extractor!(XAuthToken4, "x-auth-token");
        let token = run_extractor::<HeaderTokenExtractor<XAuthToken4>>(&[(
            "x-auth-token",
            "token-with-special.chars_123",
        )])
        .await;
        assert_eq!(token, Ok("token-with-special.chars_123".to_string()));
    }

    #[tokio::test]
    async fn test_header_token_extractor_different_header_names() {
        define_header_extractor!(ApiKey, "x-api-key");
        let token = run_extractor::<HeaderTokenExtractor<ApiKey>>(&[
            ("x-api-key", "api_key_value"),
            ("x-auth-token", "auth_token_value"),
        ])
        .await;
        assert_eq!(token, Ok("api_key_value".to_string()));
    }

    // ---- CookieTokenExtractor ------------------------------------------------

    #[tokio::test]
    async fn test_cookie_token_extractor_valid() {
        define_cookie_extractor!(AuthTokenCookie, "auth_token");
        let token = run_extractor::<CookieTokenExtractor<AuthTokenCookie>>(&[(
            "Cookie",
            "auth_token=my_jwt_token; other=value",
        )])
        .await;
        assert_eq!(token, Ok("my_jwt_token".to_string()));
    }

    #[tokio::test]
    async fn test_cookie_token_extractor_single_cookie() {
        define_cookie_extractor!(AuthTokenCookie2, "auth_token");
        let token = run_extractor::<CookieTokenExtractor<AuthTokenCookie2>>(&[(
            "Cookie",
            "auth_token=my_jwt_token",
        )])
        .await;
        assert_eq!(token, Ok("my_jwt_token".to_string()));
    }

    #[tokio::test]
    async fn test_cookie_token_extractor_missing_cookie() {
        define_cookie_extractor!(AuthTokenCookie3, "auth_token");
        let token =
            run_extractor::<CookieTokenExtractor<AuthTokenCookie3>>(&[("Cookie", "other=value")])
                .await;
        assert_eq!(token, Err(AuthError::MissingToken));
    }

    #[tokio::test]
    async fn test_cookie_token_extractor_no_cookies() {
        define_cookie_extractor!(AuthTokenCookie4, "auth_token");
        let token = run_extractor::<CookieTokenExtractor<AuthTokenCookie4>>(&[]).await;
        assert_eq!(token, Err(AuthError::MissingToken));
    }

    #[tokio::test]
    async fn test_cookie_token_extractor_multiple_cookies() {
        define_cookie_extractor!(AuthTokenCookie5, "auth_token");
        let token = run_extractor::<CookieTokenExtractor<AuthTokenCookie5>>(&[(
            "Cookie",
            "session=abc123; auth_token=my_jwt; user_id=456",
        )])
        .await;
        assert_eq!(token, Ok("my_jwt".to_string()));
    }

    #[tokio::test]
    async fn test_cookie_token_extractor_empty_value() {
        define_cookie_extractor!(AuthTokenCookie6, "auth_token");
        let token =
            run_extractor::<CookieTokenExtractor<AuthTokenCookie6>>(&[("Cookie", "auth_token=")])
                .await;
        assert_eq!(token, Ok(String::new()));
    }

    #[tokio::test]
    async fn test_cookie_token_extractor_with_spaces() {
        define_cookie_extractor!(AuthTokenCookie7, "auth_token");
        let token = run_extractor::<CookieTokenExtractor<AuthTokenCookie7>>(&[(
            "Cookie",
            "auth_token=my_jwt_token; other=value",
        )])
        .await;
        assert_eq!(token, Ok("my_jwt_token".to_string()));
    }

    // ---- AuthError -> Response -----------------------------------------------

    #[test]
    fn test_auth_error_invalid_token_response() {
        assert_eq!(status_of(AuthError::InvalidToken), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_auth_error_expired_signature_response() {
        assert_eq!(status_of(AuthError::ExpiredSignature), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_auth_error_internal_error_response() {
        assert_eq!(status_of(AuthError::InternalError), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn test_auth_error_missing_token_response() {
        assert_eq!(status_of(AuthError::MissingToken), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_auth_error_missing_required_claim_response() {
        assert_eq!(
            status_of(AuthError::MissingRequiredClaim("sub".to_string())),
            StatusCode::UNAUTHORIZED
        );
    }
}