1use std::marker::PhantomData;
2
3use axum::RequestPartsExt;
4use axum::extract::FromRef;
5use axum::http::{StatusCode, header::HeaderName};
6use axum::response::Response;
7use axum::{http::request::Parts, response::IntoResponse};
8use axum_extra::TypedHeader;
9use axum_extra::headers::authorization::Bearer;
10use axum_extra::headers::{Authorization, Cookie};
11use jsonwebtoken::errors::ErrorKind;
12use serde::de::DeserializeOwned;
13
14use crate::Decoder;
15
16#[derive(Debug)]
67pub struct Claims<T, E = BearerTokenExtractor> {
68 pub claims: T,
70 _extractor: PhantomData<E>,
71}
72
73pub trait TokenExtractor: Send + Sync {
78 fn extract_token(
82 parts: &mut Parts,
83 ) -> impl std::future::Future<Output = Result<String, AuthError>> + Send;
84}
85
/// Token extractor reading the token from an `Authorization: Bearer <token>` header.
///
/// This is the default extractor used by `Claims`.
#[derive(Debug, Clone, Copy, Default)]
pub struct BearerTokenExtractor;
90
91impl TokenExtractor for BearerTokenExtractor {
92 async fn extract_token(parts: &mut Parts) -> Result<String, AuthError> {
93 let auth: TypedHeader<Authorization<Bearer>> =
94 parts.extract().await.map_err(|_| AuthError::MissingToken)?;
95
96 Ok(auth.token().to_string())
97 }
98}
99
/// Compile-time configuration naming the header or cookie an extractor reads.
///
/// The `define_header_extractor!` and `define_cookie_extractor!` macros generate
/// implementations of this trait.
pub trait ExtractorConfig: Sync + Send {
    /// The header name or cookie name to look up.
    fn value() -> &'static str;
}
108
/// Defines a public unit struct implementing `ExtractorConfig` that names an HTTP
/// header, for use with `HeaderTokenExtractor`.
///
/// Attributes (including doc comments) written before the struct name are forwarded
/// to the generated struct, and a trailing comma is accepted.
///
/// ```ignore
/// define_header_extractor!(
///     /// Reads the token from `x-auth-token`.
///     XAuthToken,
///     "x-auth-token",
/// );
/// type XAuthTokenExtractor = HeaderTokenExtractor<XAuthToken>;
/// ```
#[macro_export]
macro_rules! define_header_extractor {
    ($(#[$attr:meta])* $name:ident, $header:expr $(,)?) => {
        $(#[$attr])*
        pub struct $name;

        impl $crate::ExtractorConfig for $name {
            fn value() -> &'static str {
                $header
            }
        }
    };
}
133
/// Defines a public unit struct implementing `ExtractorConfig` that names a cookie,
/// for use with `CookieTokenExtractor`.
///
/// Attributes (including doc comments) written before the struct name are forwarded
/// to the generated struct, and a trailing comma is accepted.
///
/// ```ignore
/// define_cookie_extractor!(
///     /// Reads the token from the `auth_token` cookie.
///     AuthTokenCookie,
///     "auth_token",
/// );
/// type AuthCookieExtractor = CookieTokenExtractor<AuthTokenCookie>;
/// ```
#[macro_export]
macro_rules! define_cookie_extractor {
    ($(#[$attr:meta])* $name:ident, $cookie:expr $(,)?) => {
        $(#[$attr])*
        pub struct $name;

        impl $crate::ExtractorConfig for $name {
            fn value() -> &'static str {
                $cookie
            }
        }
    };
}
158
159pub struct HeaderTokenExtractor<C: ExtractorConfig>(PhantomData<C>);
173
174impl<C: ExtractorConfig> TokenExtractor for HeaderTokenExtractor<C> {
175 async fn extract_token(parts: &mut Parts) -> Result<String, AuthError> {
176 let header_name = HeaderName::from_static(C::value());
177
178 parts
179 .headers
180 .get(&header_name)
181 .and_then(|h| h.to_str().ok())
182 .map(|s| s.to_string())
183 .ok_or(AuthError::MissingToken)
184 }
185}
186
187pub struct CookieTokenExtractor<C: ExtractorConfig>(PhantomData<C>);
201
202impl<C: ExtractorConfig> TokenExtractor for CookieTokenExtractor<C> {
203 async fn extract_token(parts: &mut Parts) -> Result<String, AuthError> {
204 let cookies: TypedHeader<Cookie> =
205 parts.extract().await.map_err(|_| AuthError::MissingToken)?;
206
207 cookies
208 .get(C::value())
209 .map(|s| s.to_string())
210 .ok_or(AuthError::MissingToken)
211 }
212}
213
214impl<S, T, E> axum::extract::FromRequestParts<S> for Claims<T, E>
215where
216 Decoder<T>: FromRef<S>,
217 S: Send + Sync,
218 T: DeserializeOwned,
219 E: TokenExtractor,
220{
221 type Rejection = AuthError;
222
223 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
224 let token = E::extract_token(parts).await?;
225
226 let decoder = Decoder::from_ref(state);
227 let token_data = decoder.decode(&token).await.map_err(map_jwt_error)?;
228
229 Ok(Claims {
230 claims: token_data.claims,
231 _extractor: PhantomData,
232 })
233 }
234}
235
236fn map_jwt_error(err: crate::Error) -> AuthError {
238 match err {
239 crate::Error::Jwt(e) => match e.kind() {
240 ErrorKind::ExpiredSignature => AuthError::ExpiredSignature,
241 ErrorKind::InvalidSignature => AuthError::InvalidSignature,
242 ErrorKind::InvalidAudience => AuthError::InvalidAudience,
243 ErrorKind::InvalidAlgorithm => AuthError::InvalidAlgorithm,
244 ErrorKind::InvalidToken => AuthError::InvalidToken,
245 ErrorKind::InvalidIssuer => AuthError::InvalidIssuer,
246 ErrorKind::InvalidSubject => AuthError::InvalidSubject,
247 ErrorKind::ImmatureSignature => AuthError::ImmatureSignature,
248 ErrorKind::MissingAlgorithm => AuthError::MissingAlgorithm,
249 ErrorKind::MissingRequiredClaim(claim) => {
250 AuthError::MissingRequiredClaim(claim.to_string())
251 }
252 _ => AuthError::InternalError,
253 },
254 _ => AuthError::InternalError,
255 }
256}
257
258#[derive(Debug, PartialEq, thiserror::Error)]
262pub enum AuthError {
263 #[error("Invalid token")]
265 InvalidToken,
266
267 #[error("Invalid signature")]
269 InvalidSignature,
270
271 #[error("Missing required claim: {0}")]
273 MissingRequiredClaim(String),
274
275 #[error("Expired signature")]
277 ExpiredSignature,
278
279 #[error("Invalid issuer")]
281 InvalidIssuer,
282
283 #[error("Invalid audience")]
285 InvalidAudience,
286
287 #[error("Invalid subject")]
289 InvalidSubject,
290
291 #[error("Immature signature")]
293 ImmatureSignature,
294
295 #[error("Invalid algorithm")]
297 InvalidAlgorithm,
298
299 #[error("Missing algorithm")]
301 MissingAlgorithm,
302
303 #[error("Missing token")]
305 MissingToken,
306
307 #[error("Internal error")]
309 InternalError,
310}
311
312impl IntoResponse for AuthError {
313 fn into_response(self) -> Response {
314 let (status, msg) = match self {
315 AuthError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"),
316 AuthError::InvalidSignature => (StatusCode::UNAUTHORIZED, "Invalid signature"),
317 AuthError::MissingRequiredClaim(_) => {
318 (StatusCode::UNAUTHORIZED, "Missing required claim")
319 }
320 AuthError::ExpiredSignature => (StatusCode::UNAUTHORIZED, "Expired signature"),
321 AuthError::InvalidIssuer => (StatusCode::UNAUTHORIZED, "Invalid issuer"),
322 AuthError::InvalidAudience => (StatusCode::UNAUTHORIZED, "Invalid audience"),
323 AuthError::InvalidSubject => (StatusCode::UNAUTHORIZED, "Invalid subject"),
324 AuthError::ImmatureSignature => (StatusCode::UNAUTHORIZED, "Immature signature"),
325 AuthError::InvalidAlgorithm => (StatusCode::UNAUTHORIZED, "Invalid algorithm"),
326 AuthError::MissingAlgorithm => (StatusCode::UNAUTHORIZED, "Missing algorithm"),
327 AuthError::MissingToken => (StatusCode::UNAUTHORIZED, "Missing token"),
328 AuthError::InternalError => (StatusCode::INTERNAL_SERVER_ERROR, "Internal error"),
329 };
330
331 (status, msg).into_response()
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::extract::Request;

    /// Builds request `Parts` carrying the given `(name, value)` headers.
    fn parts_with(headers: &[(&str, &str)]) -> Parts {
        let mut builder = Request::builder();
        for &(name, value) in headers {
            builder = builder.header(name, value);
        }
        builder.body(Body::empty()).unwrap().into_parts().0
    }

    /// Runs a bare jsonwebtoken error kind through `map_jwt_error`.
    fn map_kind(kind: ErrorKind) -> AuthError {
        map_jwt_error(crate::Error::Jwt(jsonwebtoken::errors::Error::from(kind)))
    }

    // ----- configuration macros -----

    #[test]
    fn test_header_extractor_macro() {
        define_header_extractor!(TestHeader, "x-test-header");
        assert_eq!(TestHeader::value(), "x-test-header");
    }

    #[test]
    fn test_cookie_extractor_macro() {
        define_cookie_extractor!(TestCookie, "test_cookie");
        assert_eq!(TestCookie::value(), "test_cookie");
    }

    // ----- map_jwt_error (synchronous: no runtime needed) -----

    #[test]
    fn test_map_jwt_error_expired_signature() {
        assert_eq!(map_kind(ErrorKind::ExpiredSignature), AuthError::ExpiredSignature);
    }

    #[test]
    fn test_map_jwt_error_invalid_signature() {
        assert_eq!(map_kind(ErrorKind::InvalidSignature), AuthError::InvalidSignature);
    }

    #[test]
    fn test_map_jwt_error_invalid_audience() {
        assert_eq!(map_kind(ErrorKind::InvalidAudience), AuthError::InvalidAudience);
    }

    #[test]
    fn test_map_jwt_error_invalid_algorithm() {
        assert_eq!(map_kind(ErrorKind::InvalidAlgorithm), AuthError::InvalidAlgorithm);
    }

    #[test]
    fn test_map_jwt_error_invalid_token() {
        assert_eq!(map_kind(ErrorKind::InvalidToken), AuthError::InvalidToken);
    }

    #[test]
    fn test_map_jwt_error_invalid_issuer() {
        assert_eq!(map_kind(ErrorKind::InvalidIssuer), AuthError::InvalidIssuer);
    }

    #[test]
    fn test_map_jwt_error_invalid_subject() {
        assert_eq!(map_kind(ErrorKind::InvalidSubject), AuthError::InvalidSubject);
    }

    #[test]
    fn test_map_jwt_error_immature_signature() {
        assert_eq!(map_kind(ErrorKind::ImmatureSignature), AuthError::ImmatureSignature);
    }

    #[test]
    fn test_map_jwt_error_missing_algorithm() {
        assert_eq!(map_kind(ErrorKind::MissingAlgorithm), AuthError::MissingAlgorithm);
    }

    #[test]
    fn test_map_jwt_error_missing_required_claim() {
        assert_eq!(
            map_kind(ErrorKind::MissingRequiredClaim("sub".to_string())),
            AuthError::MissingRequiredClaim("sub".to_string())
        );
    }

    #[test]
    fn test_map_jwt_error_non_jwt_error() {
        let error = crate::Error::KeyNotFound(Some("test_kid".to_string()));
        assert_eq!(map_jwt_error(error), AuthError::InternalError);
    }

    // ----- BearerTokenExtractor -----

    #[tokio::test]
    async fn test_bearer_token_extractor_valid() {
        let token = BearerTokenExtractor::extract_token(&mut parts_with(&[(
            "Authorization",
            "Bearer test_token",
        )]))
        .await;
        assert_eq!(token, Ok("test_token".to_string()));
    }

    #[tokio::test]
    async fn test_bearer_token_extractor_valid_long_token() {
        let long_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";
        let header = format!("Bearer {long_token}");
        let token =
            BearerTokenExtractor::extract_token(&mut parts_with(&[("Authorization", header.as_str())]))
                .await;
        assert_eq!(token, Ok(long_token.to_string()));
    }

    #[tokio::test]
    async fn test_bearer_token_extractor_invalid_scheme() {
        let token = BearerTokenExtractor::extract_token(&mut parts_with(&[(
            "Authorization",
            "Basic dXNlcjpwYXNz",
        )]))
        .await;
        assert_eq!(token, Err(AuthError::MissingToken));
    }

    #[tokio::test]
    async fn test_bearer_token_extractor_malformed_header() {
        let token = BearerTokenExtractor::extract_token(&mut parts_with(&[(
            "Authorization",
            "BearerMissingSpace",
        )]))
        .await;
        assert_eq!(token, Err(AuthError::MissingToken));
    }

    #[tokio::test]
    async fn test_bearer_token_extractor_missing_header() {
        let token = BearerTokenExtractor::extract_token(&mut parts_with(&[])).await;
        assert_eq!(token, Err(AuthError::MissingToken));
    }

    #[tokio::test]
    async fn test_bearer_token_extractor_empty_token() {
        // The extractor only locates the token; it does not validate its contents.
        let token =
            BearerTokenExtractor::extract_token(&mut parts_with(&[("Authorization", "Bearer ")]))
                .await;
        assert_eq!(token, Ok(String::new()));
    }

    #[tokio::test]
    async fn test_bearer_token_extractor_case_sensitivity() {
        // The scheme name is matched case-insensitively.
        let token = BearerTokenExtractor::extract_token(&mut parts_with(&[(
            "Authorization",
            "bearer test_token",
        )]))
        .await;
        assert_eq!(token, Ok("test_token".to_string()));
    }

    // ----- HeaderTokenExtractor -----

    #[tokio::test]
    async fn test_header_token_extractor_valid() {
        define_header_extractor!(XAuthToken, "x-auth-token");
        type XAuthTokenExtractor = HeaderTokenExtractor<XAuthToken>;

        let token =
            XAuthTokenExtractor::extract_token(&mut parts_with(&[("x-auth-token", "test_token_123")]))
                .await;
        assert_eq!(token, Ok("test_token_123".to_string()));
    }

    #[tokio::test]
    async fn test_header_token_extractor_missing_header() {
        define_header_extractor!(XAuthToken2, "x-auth-token");
        type XAuthTokenExtractor = HeaderTokenExtractor<XAuthToken2>;

        let token = XAuthTokenExtractor::extract_token(&mut parts_with(&[])).await;
        assert_eq!(token, Err(AuthError::MissingToken));
    }

    #[tokio::test]
    async fn test_header_token_extractor_empty_value() {
        define_header_extractor!(XAuthToken3, "x-auth-token");
        type XAuthTokenExtractor = HeaderTokenExtractor<XAuthToken3>;

        let token =
            XAuthTokenExtractor::extract_token(&mut parts_with(&[("x-auth-token", "")])).await;
        assert_eq!(token, Ok(String::new()));
    }

    #[tokio::test]
    async fn test_header_token_extractor_special_characters() {
        define_header_extractor!(XAuthToken4, "x-auth-token");
        type XAuthTokenExtractor = HeaderTokenExtractor<XAuthToken4>;

        let token = XAuthTokenExtractor::extract_token(&mut parts_with(&[(
            "x-auth-token",
            "token-with-special.chars_123",
        )]))
        .await;
        assert_eq!(token, Ok("token-with-special.chars_123".to_string()));
    }

    #[tokio::test]
    async fn test_header_token_extractor_different_header_names() {
        define_header_extractor!(ApiKey, "x-api-key");
        type ApiKeyExtractor = HeaderTokenExtractor<ApiKey>;

        let token = ApiKeyExtractor::extract_token(&mut parts_with(&[
            ("x-api-key", "api_key_value"),
            ("x-auth-token", "auth_token_value"),
        ]))
        .await;
        assert_eq!(token, Ok("api_key_value".to_string()));
    }

    // ----- CookieTokenExtractor -----

    #[tokio::test]
    async fn test_cookie_token_extractor_valid() {
        define_cookie_extractor!(AuthTokenCookie, "auth_token");
        type AuthCookieExtractor = CookieTokenExtractor<AuthTokenCookie>;

        let token = AuthCookieExtractor::extract_token(&mut parts_with(&[(
            "Cookie",
            "auth_token=my_jwt_token; other=value",
        )]))
        .await;
        assert_eq!(token, Ok("my_jwt_token".to_string()));
    }

    #[tokio::test]
    async fn test_cookie_token_extractor_single_cookie() {
        define_cookie_extractor!(AuthTokenCookie2, "auth_token");
        type AuthCookieExtractor = CookieTokenExtractor<AuthTokenCookie2>;

        let token =
            AuthCookieExtractor::extract_token(&mut parts_with(&[("Cookie", "auth_token=my_jwt_token")]))
                .await;
        assert_eq!(token, Ok("my_jwt_token".to_string()));
    }

    #[tokio::test]
    async fn test_cookie_token_extractor_missing_cookie() {
        define_cookie_extractor!(AuthTokenCookie3, "auth_token");
        type AuthCookieExtractor = CookieTokenExtractor<AuthTokenCookie3>;

        let token =
            AuthCookieExtractor::extract_token(&mut parts_with(&[("Cookie", "other=value")])).await;
        assert_eq!(token, Err(AuthError::MissingToken));
    }

    #[tokio::test]
    async fn test_cookie_token_extractor_no_cookies() {
        define_cookie_extractor!(AuthTokenCookie4, "auth_token");
        type AuthCookieExtractor = CookieTokenExtractor<AuthTokenCookie4>;

        let token = AuthCookieExtractor::extract_token(&mut parts_with(&[])).await;
        assert_eq!(token, Err(AuthError::MissingToken));
    }

    #[tokio::test]
    async fn test_cookie_token_extractor_multiple_cookies() {
        define_cookie_extractor!(AuthTokenCookie5, "auth_token");
        type AuthCookieExtractor = CookieTokenExtractor<AuthTokenCookie5>;

        let token = AuthCookieExtractor::extract_token(&mut parts_with(&[(
            "Cookie",
            "session=abc123; auth_token=my_jwt; user_id=456",
        )]))
        .await;
        assert_eq!(token, Ok("my_jwt".to_string()));
    }

    #[tokio::test]
    async fn test_cookie_token_extractor_empty_value() {
        define_cookie_extractor!(AuthTokenCookie6, "auth_token");
        type AuthCookieExtractor = CookieTokenExtractor<AuthTokenCookie6>;

        let token =
            AuthCookieExtractor::extract_token(&mut parts_with(&[("Cookie", "auth_token=")])).await;
        assert_eq!(token, Ok(String::new()));
    }

    #[tokio::test]
    async fn test_cookie_token_extractor_with_spaces() {
        define_cookie_extractor!(AuthTokenCookie7, "auth_token");
        type AuthCookieExtractor = CookieTokenExtractor<AuthTokenCookie7>;

        // Irregular whitespace around the separators: several spaces before the
        // target cookie and none before the following one.
        let token = AuthCookieExtractor::extract_token(&mut parts_with(&[(
            "Cookie",
            "other=value;   auth_token=my_jwt_token;session=abc123",
        )]))
        .await;
        assert_eq!(token, Ok("my_jwt_token".to_string()));
    }

    // ----- AuthError -> HTTP response (synchronous) -----

    #[test]
    fn test_auth_error_invalid_token_response() {
        let response = AuthError::InvalidToken.into_response();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_auth_error_expired_signature_response() {
        let response = AuthError::ExpiredSignature.into_response();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_auth_error_internal_error_response() {
        let response = AuthError::InternalError.into_response();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn test_auth_error_missing_token_response() {
        let response = AuthError::MissingToken.into_response();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_auth_error_missing_required_claim_response() {
        let response = AuthError::MissingRequiredClaim("sub".to_string()).into_response();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }
}