1use std::marker::PhantomData;
2
3use axum::RequestPartsExt;
4use axum::extract::FromRef;
5use axum::http::{StatusCode, header::HeaderName};
6use axum::response::Response;
7use axum::{http::request::Parts, response::IntoResponse};
8use axum_extra::TypedHeader;
9use axum_extra::headers::authorization::Bearer;
10use axum_extra::headers::{Authorization, Cookie};
11use jsonwebtoken::errors::ErrorKind;
12use serde::de::DeserializeOwned;
13
14use crate::Decoder;
15
16#[derive(Debug)]
42pub struct Claims<T, E = BearerTokenExtractor> {
43 pub claims: T,
45 _extractor: PhantomData<E>,
46}
47
48pub trait TokenExtractor: Send + Sync {
53 fn extract_token(
57 parts: &mut Parts,
58 ) -> impl std::future::Future<Output = Result<String, AuthError>> + Send;
59}
60
61pub struct BearerTokenExtractor;
65
66impl TokenExtractor for BearerTokenExtractor {
67 async fn extract_token(parts: &mut Parts) -> Result<String, AuthError> {
68 let auth: TypedHeader<Authorization<Bearer>> =
69 parts.extract().await.map_err(|_| AuthError::MissingToken)?;
70
71 Ok(auth.token().to_string())
72 }
73}
74
/// Compile-time configuration for the header- and cookie-based token extractors.
///
/// The single associated function supplies the name that
/// [`HeaderTokenExtractor`] looks up as a header and [`CookieTokenExtractor`]
/// looks up as a cookie.
pub trait ExtractorConfig: Send + Sync {
    /// Returns the configured header or cookie name.
    fn value() -> &'static str;
}
83
/// Declares a unit struct that implements `ExtractorConfig` with the given
/// header name, for use with `HeaderTokenExtractor`.
///
/// Attributes and doc comments written before the type name are forwarded to
/// the generated struct, and a trailing comma is accepted.
///
/// ```ignore
/// define_header_extractor!(XAuthToken, "x-auth-token");
///
/// define_header_extractor!(
///     /// Reads the token from `x-api-key`.
///     #[derive(Debug)]
///     ApiKey,
///     "x-api-key",
/// );
/// ```
#[macro_export]
macro_rules! define_header_extractor {
    ($(#[$attr:meta])* $name:ident, $header:expr $(,)?) => {
        $(#[$attr])*
        pub struct $name;

        impl $crate::ExtractorConfig for $name {
            fn value() -> &'static str {
                $header
            }
        }
    };
}
108
/// Declares a unit struct that implements `ExtractorConfig` with the given
/// cookie name, for use with `CookieTokenExtractor`.
///
/// Attributes and doc comments written before the type name are forwarded to
/// the generated struct, and a trailing comma is accepted.
///
/// ```ignore
/// define_cookie_extractor!(AuthTokenCookie, "auth_token");
///
/// define_cookie_extractor!(
///     /// Reads the token from the `session` cookie.
///     #[derive(Debug)]
///     SessionCookie,
///     "session",
/// );
/// ```
#[macro_export]
macro_rules! define_cookie_extractor {
    ($(#[$attr:meta])* $name:ident, $cookie:expr $(,)?) => {
        $(#[$attr])*
        pub struct $name;

        impl $crate::ExtractorConfig for $name {
            fn value() -> &'static str {
                $cookie
            }
        }
    };
}
133
134pub struct HeaderTokenExtractor<C: ExtractorConfig>(PhantomData<C>);
148
149impl<C: ExtractorConfig> TokenExtractor for HeaderTokenExtractor<C> {
150 async fn extract_token(parts: &mut Parts) -> Result<String, AuthError> {
151 let header_name = HeaderName::from_static(C::value());
152
153 parts
154 .headers
155 .get(&header_name)
156 .and_then(|h| h.to_str().ok())
157 .map(|s| s.to_string())
158 .ok_or(AuthError::MissingToken)
159 }
160}
161
162pub struct CookieTokenExtractor<C: ExtractorConfig>(PhantomData<C>);
176
177impl<C: ExtractorConfig> TokenExtractor for CookieTokenExtractor<C> {
178 async fn extract_token(parts: &mut Parts) -> Result<String, AuthError> {
179 let cookies: TypedHeader<Cookie> =
180 parts.extract().await.map_err(|_| AuthError::MissingToken)?;
181
182 cookies
183 .get(C::value())
184 .map(|s| s.to_string())
185 .ok_or(AuthError::MissingToken)
186 }
187}
188
189impl<S, T, E> axum::extract::FromRequestParts<S> for Claims<T, E>
190where
191 JwtDecoderState<T>: FromRef<S>,
192 S: Send + Sync,
193 T: DeserializeOwned,
194 E: TokenExtractor,
195{
196 type Rejection = AuthError;
197
198 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
199 let token = E::extract_token(parts).await?;
200
201 let state = JwtDecoderState::from_ref(state);
202 let token_data = state
203 .decoder
204 .clone()
205 .decode(&token)
206 .await
207 .map_err(map_jwt_error)?;
208
209 Ok(Claims {
210 claims: token_data.claims,
211 _extractor: PhantomData,
212 })
213 }
214}
215
216fn map_jwt_error(err: crate::Error) -> AuthError {
218 match err {
219 crate::Error::Jwt(e) => match e.kind() {
220 ErrorKind::ExpiredSignature => AuthError::ExpiredSignature,
221 ErrorKind::InvalidSignature => AuthError::InvalidSignature,
222 ErrorKind::InvalidAudience => AuthError::InvalidAudience,
223 ErrorKind::InvalidAlgorithm => AuthError::InvalidAlgorithm,
224 ErrorKind::InvalidToken => AuthError::InvalidToken,
225 ErrorKind::InvalidIssuer => AuthError::InvalidIssuer,
226 ErrorKind::InvalidSubject => AuthError::InvalidSubject,
227 ErrorKind::ImmatureSignature => AuthError::ImmatureSignature,
228 ErrorKind::MissingAlgorithm => AuthError::MissingAlgorithm,
229 ErrorKind::MissingRequiredClaim(claim) => {
230 AuthError::MissingRequiredClaim(claim.to_string())
231 }
232 _ => AuthError::InternalError,
233 },
234 _ => AuthError::InternalError,
235 }
236}
237
238#[derive(Debug, PartialEq, thiserror::Error)]
242pub enum AuthError {
243 #[error("Invalid token")]
245 InvalidToken,
246
247 #[error("Invalid signature")]
249 InvalidSignature,
250
251 #[error("Missing required claim: {0}")]
253 MissingRequiredClaim(String),
254
255 #[error("Expired signature")]
257 ExpiredSignature,
258
259 #[error("Invalid issuer")]
261 InvalidIssuer,
262
263 #[error("Invalid audience")]
265 InvalidAudience,
266
267 #[error("Invalid subject")]
269 InvalidSubject,
270
271 #[error("Immature signature")]
273 ImmatureSignature,
274
275 #[error("Invalid algorithm")]
277 InvalidAlgorithm,
278
279 #[error("Missing algorithm")]
281 MissingAlgorithm,
282
283 #[error("Missing token")]
285 MissingToken,
286
287 #[error("Internal error")]
289 InternalError,
290}
291
292impl IntoResponse for AuthError {
293 fn into_response(self) -> Response {
294 let (status, msg) = match self {
295 AuthError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"),
296 AuthError::InvalidSignature => (StatusCode::UNAUTHORIZED, "Invalid signature"),
297 AuthError::MissingRequiredClaim(_) => {
298 (StatusCode::UNAUTHORIZED, "Missing required claim")
299 }
300 AuthError::ExpiredSignature => (StatusCode::UNAUTHORIZED, "Expired signature"),
301 AuthError::InvalidIssuer => (StatusCode::UNAUTHORIZED, "Invalid issuer"),
302 AuthError::InvalidAudience => (StatusCode::UNAUTHORIZED, "Invalid audience"),
303 AuthError::InvalidSubject => (StatusCode::UNAUTHORIZED, "Invalid subject"),
304 AuthError::ImmatureSignature => (StatusCode::UNAUTHORIZED, "Immature signature"),
305 AuthError::InvalidAlgorithm => (StatusCode::UNAUTHORIZED, "Invalid algorithm"),
306 AuthError::MissingAlgorithm => (StatusCode::UNAUTHORIZED, "Missing algorithm"),
307 AuthError::MissingToken => (StatusCode::UNAUTHORIZED, "Missing token"),
308 AuthError::InternalError => (StatusCode::INTERNAL_SERVER_ERROR, "Internal error"),
309 };
310
311 (status, msg).into_response()
312 }
313}
314
315#[derive(Clone)]
319pub struct JwtDecoderState<T> {
320 pub decoder: Decoder<T>,
322}
323
324impl<T> FromRef<JwtDecoderState<T>> for Decoder<T> {
325 fn from_ref(state: &JwtDecoderState<T>) -> Self {
326 state.decoder.clone()
327 }
328}
329
#[cfg(test)]
mod tests {
    //! Unit tests for the token extractors, the JWT error mapping and the
    //! HTTP rendering of [`AuthError`].
    //!
    //! Only tests that actually await an extractor run on the tokio runtime;
    //! purely synchronous checks use plain `#[test]`.

    use super::*;
    use axum::body::Body;
    use axum::extract::Request;

    // Marker types shared by the header and cookie extractor tests.
    define_header_extractor!(XAuthToken, "x-auth-token");
    define_header_extractor!(ApiKey, "x-api-key");
    define_cookie_extractor!(AuthTokenCookie, "auth_token");

    type XAuthTokenExtractor = HeaderTokenExtractor<XAuthToken>;
    type ApiKeyExtractor = HeaderTokenExtractor<ApiKey>;
    type AuthCookieExtractor = CookieTokenExtractor<AuthTokenCookie>;

    /// Builds the `Parts` of a body-less request carrying `headers`.
    fn parts_with(headers: &[(&str, &str)]) -> Parts {
        let mut builder = Request::builder();
        for &(name, value) in headers {
            builder = builder.header(name, value);
        }
        let (parts, _body) = builder
            .body(Body::empty())
            .expect("test request must be well-formed")
            .into_parts();
        parts
    }

    /// Runs `map_jwt_error` on a JWT error of the given kind.
    fn map_kind(kind: ErrorKind) -> AuthError {
        map_jwt_error(crate::Error::Jwt(jsonwebtoken::errors::Error::from(kind)))
    }

    // --- macros ---------------------------------------------------------

    #[test]
    fn test_header_extractor_macro() {
        define_header_extractor!(TestHeader, "x-test-header");
        assert_eq!(TestHeader::value(), "x-test-header");
    }

    #[test]
    fn test_cookie_extractor_macro() {
        define_cookie_extractor!(TestCookie, "test_cookie");
        assert_eq!(TestCookie::value(), "test_cookie");
    }

    // --- map_jwt_error --------------------------------------------------

    #[test]
    fn test_map_jwt_error_expired_signature() {
        assert_eq!(
            map_kind(ErrorKind::ExpiredSignature),
            AuthError::ExpiredSignature
        );
    }

    #[test]
    fn test_map_jwt_error_invalid_signature() {
        assert_eq!(
            map_kind(ErrorKind::InvalidSignature),
            AuthError::InvalidSignature
        );
    }

    #[test]
    fn test_map_jwt_error_invalid_audience() {
        assert_eq!(
            map_kind(ErrorKind::InvalidAudience),
            AuthError::InvalidAudience
        );
    }

    #[test]
    fn test_map_jwt_error_invalid_algorithm() {
        assert_eq!(
            map_kind(ErrorKind::InvalidAlgorithm),
            AuthError::InvalidAlgorithm
        );
    }

    #[test]
    fn test_map_jwt_error_invalid_token() {
        assert_eq!(map_kind(ErrorKind::InvalidToken), AuthError::InvalidToken);
    }

    #[test]
    fn test_map_jwt_error_invalid_issuer() {
        assert_eq!(map_kind(ErrorKind::InvalidIssuer), AuthError::InvalidIssuer);
    }

    #[test]
    fn test_map_jwt_error_invalid_subject() {
        assert_eq!(
            map_kind(ErrorKind::InvalidSubject),
            AuthError::InvalidSubject
        );
    }

    #[test]
    fn test_map_jwt_error_immature_signature() {
        assert_eq!(
            map_kind(ErrorKind::ImmatureSignature),
            AuthError::ImmatureSignature
        );
    }

    #[test]
    fn test_map_jwt_error_missing_algorithm() {
        assert_eq!(
            map_kind(ErrorKind::MissingAlgorithm),
            AuthError::MissingAlgorithm
        );
    }

    #[test]
    fn test_map_jwt_error_missing_required_claim() {
        assert_eq!(
            map_kind(ErrorKind::MissingRequiredClaim("sub".to_string())),
            AuthError::MissingRequiredClaim("sub".to_string())
        );
    }

    #[test]
    fn test_map_jwt_error_non_jwt_error() {
        let error = crate::Error::KeyNotFound(Some("test_kid".to_string()));
        assert_eq!(map_jwt_error(error), AuthError::InternalError);
    }

    // --- BearerTokenExtractor -------------------------------------------

    #[tokio::test]
    async fn test_bearer_token_extractor_valid() {
        let mut parts = parts_with(&[("Authorization", "Bearer test_token")]);
        let token = BearerTokenExtractor::extract_token(&mut parts).await;
        assert_eq!(token.as_deref(), Ok("test_token"));
    }

    #[tokio::test]
    async fn test_bearer_token_extractor_valid_long_token() {
        let long_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";
        let header = format!("Bearer {}", long_token);
        let mut parts = parts_with(&[("Authorization", header.as_str())]);
        let token = BearerTokenExtractor::extract_token(&mut parts).await;
        assert_eq!(token.as_deref(), Ok(long_token));
    }

    #[tokio::test]
    async fn test_bearer_token_extractor_invalid_scheme() {
        let mut parts = parts_with(&[("Authorization", "Basic dXNlcjpwYXNz")]);
        let token = BearerTokenExtractor::extract_token(&mut parts).await;
        assert_eq!(token, Err(AuthError::MissingToken));
    }

    #[tokio::test]
    async fn test_bearer_token_extractor_malformed_header() {
        let mut parts = parts_with(&[("Authorization", "BearerMissingSpace")]);
        let token = BearerTokenExtractor::extract_token(&mut parts).await;
        assert_eq!(token, Err(AuthError::MissingToken));
    }

    #[tokio::test]
    async fn test_bearer_token_extractor_missing_header() {
        let mut parts = parts_with(&[]);
        let token = BearerTokenExtractor::extract_token(&mut parts).await;
        assert_eq!(token, Err(AuthError::MissingToken));
    }

    #[tokio::test]
    async fn test_bearer_token_extractor_empty_token() {
        // An empty bearer token is handed through as an empty string.
        let mut parts = parts_with(&[("Authorization", "Bearer ")]);
        let token = BearerTokenExtractor::extract_token(&mut parts).await;
        assert_eq!(token.as_deref(), Ok(""));
    }

    #[tokio::test]
    async fn test_bearer_token_extractor_case_sensitivity() {
        // A lowercase `bearer` scheme is accepted as well.
        let mut parts = parts_with(&[("Authorization", "bearer test_token")]);
        let token = BearerTokenExtractor::extract_token(&mut parts).await;
        assert_eq!(token.as_deref(), Ok("test_token"));
    }

    // --- HeaderTokenExtractor -------------------------------------------

    #[tokio::test]
    async fn test_header_token_extractor_valid() {
        let mut parts = parts_with(&[("x-auth-token", "test_token_123")]);
        let token = XAuthTokenExtractor::extract_token(&mut parts).await;
        assert_eq!(token.as_deref(), Ok("test_token_123"));
    }

    #[tokio::test]
    async fn test_header_token_extractor_missing_header() {
        let mut parts = parts_with(&[]);
        let token = XAuthTokenExtractor::extract_token(&mut parts).await;
        assert_eq!(token, Err(AuthError::MissingToken));
    }

    #[tokio::test]
    async fn test_header_token_extractor_empty_value() {
        let mut parts = parts_with(&[("x-auth-token", "")]);
        let token = XAuthTokenExtractor::extract_token(&mut parts).await;
        assert_eq!(token.as_deref(), Ok(""));
    }

    #[tokio::test]
    async fn test_header_token_extractor_special_characters() {
        let mut parts = parts_with(&[("x-auth-token", "token-with-special.chars_123")]);
        let token = XAuthTokenExtractor::extract_token(&mut parts).await;
        assert_eq!(token.as_deref(), Ok("token-with-special.chars_123"));
    }

    #[tokio::test]
    async fn test_header_token_extractor_different_header_names() {
        let mut parts = parts_with(&[
            ("x-api-key", "api_key_value"),
            ("x-auth-token", "auth_token_value"),
        ]);
        let token = ApiKeyExtractor::extract_token(&mut parts).await;
        assert_eq!(token.as_deref(), Ok("api_key_value"));
    }

    // --- CookieTokenExtractor -------------------------------------------

    #[tokio::test]
    async fn test_cookie_token_extractor_valid() {
        let mut parts = parts_with(&[("Cookie", "auth_token=my_jwt_token; other=value")]);
        let token = AuthCookieExtractor::extract_token(&mut parts).await;
        assert_eq!(token.as_deref(), Ok("my_jwt_token"));
    }

    #[tokio::test]
    async fn test_cookie_token_extractor_single_cookie() {
        let mut parts = parts_with(&[("Cookie", "auth_token=my_jwt_token")]);
        let token = AuthCookieExtractor::extract_token(&mut parts).await;
        assert_eq!(token.as_deref(), Ok("my_jwt_token"));
    }

    #[tokio::test]
    async fn test_cookie_token_extractor_missing_cookie() {
        let mut parts = parts_with(&[("Cookie", "other=value")]);
        let token = AuthCookieExtractor::extract_token(&mut parts).await;
        assert_eq!(token, Err(AuthError::MissingToken));
    }

    #[tokio::test]
    async fn test_cookie_token_extractor_no_cookies() {
        let mut parts = parts_with(&[]);
        let token = AuthCookieExtractor::extract_token(&mut parts).await;
        assert_eq!(token, Err(AuthError::MissingToken));
    }

    #[tokio::test]
    async fn test_cookie_token_extractor_multiple_cookies() {
        let mut parts = parts_with(&[(
            "Cookie",
            "session=abc123; auth_token=my_jwt; user_id=456",
        )]);
        let token = AuthCookieExtractor::extract_token(&mut parts).await;
        assert_eq!(token.as_deref(), Ok("my_jwt"));
    }

    #[tokio::test]
    async fn test_cookie_token_extractor_empty_value() {
        let mut parts = parts_with(&[("Cookie", "auth_token=")]);
        let token = AuthCookieExtractor::extract_token(&mut parts).await;
        assert_eq!(token.as_deref(), Ok(""));
    }

    #[tokio::test]
    async fn test_cookie_token_extractor_with_spaces() {
        let mut parts = parts_with(&[("Cookie", "auth_token=my_jwt_token; other=value")]);
        let token = AuthCookieExtractor::extract_token(&mut parts).await;
        assert_eq!(token.as_deref(), Ok("my_jwt_token"));
    }

    // --- AuthError -> Response ------------------------------------------

    #[test]
    fn test_auth_error_invalid_token_response() {
        let response = AuthError::InvalidToken.into_response();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_auth_error_expired_signature_response() {
        let response = AuthError::ExpiredSignature.into_response();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_auth_error_internal_error_response() {
        let response = AuthError::InternalError.into_response();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn test_auth_error_missing_token_response() {
        let response = AuthError::MissingToken.into_response();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_auth_error_missing_required_claim_response() {
        let response = AuthError::MissingRequiredClaim("sub".to_string()).into_response();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }
}