1use async_trait::async_trait;
2use axum::extract::FromRef;
3use axum::http::StatusCode;
4use axum::response::Response;
5use axum::RequestPartsExt;
6use axum::{http::request::Parts, response::IntoResponse};
7use axum_extra::headers::authorization::Bearer;
8use axum_extra::headers::Authorization;
9use axum_extra::TypedHeader;
10use jsonwebtoken::errors::ErrorKind;
11use serde::de::DeserializeOwned;
12use serde::Deserialize;
13
14use crate::Decoder;
15
16#[derive(Debug, Deserialize)]
18pub struct Claims<T>(pub T);
19
20#[async_trait]
22pub trait TokenExtractor {
23 async fn extract_token(parts: &mut Parts) -> Result<String, AuthError>;
24}
25
/// [`TokenExtractor`] that reads the token from an `Authorization: Bearer <token>` header.
///
/// Stateless unit type. The common traits are derived so it can be debug-printed, copied and
/// default-constructed freely.
#[derive(Debug, Clone, Copy, Default)]
pub struct BearerTokenExtractor;
28
29#[async_trait]
30impl TokenExtractor for BearerTokenExtractor {
31 async fn extract_token(parts: &mut Parts) -> Result<String, AuthError> {
32 let auth: TypedHeader<Authorization<Bearer>> =
33 parts.extract().await.map_err(|_| AuthError::MissingToken)?;
34
35 Ok(auth.token().to_string())
36 }
37}
38
39impl<S, T> axum::extract::FromRequestParts<S> for Claims<T>
40where
41 JwtDecoderState<T>: FromRef<S>,
42 S: Send + Sync,
43 T: DeserializeOwned,
44{
45 type Rejection = AuthError;
46
47 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
48 let token = BearerTokenExtractor::extract_token(parts).await?;
50
51 let state = JwtDecoderState::from_ref(state);
52 let token_data = state
53 .decoder
54 .clone()
55 .decode(&token)
56 .await
57 .map_err(map_jwt_error)?;
58
59 Ok(Claims(token_data.claims))
60 }
61}
62
63fn map_jwt_error(err: crate::Error) -> AuthError {
65 match err {
66 crate::Error::Jwt(e) => match e.kind() {
67 ErrorKind::ExpiredSignature => AuthError::ExpiredSignature,
68 ErrorKind::InvalidSignature => AuthError::InvalidSignature,
69 ErrorKind::InvalidAudience => AuthError::InvalidAudience,
70 ErrorKind::InvalidAlgorithm => AuthError::InvalidAlgorithm,
71 ErrorKind::InvalidToken => AuthError::InvalidToken,
72 ErrorKind::InvalidIssuer => AuthError::InvalidIssuer,
73 ErrorKind::InvalidSubject => AuthError::InvalidSubject,
74 ErrorKind::ImmatureSignature => AuthError::ImmatureSignature,
75 ErrorKind::MissingAlgorithm => AuthError::MissingAlgorithm,
76 ErrorKind::MissingRequiredClaim(claim) => {
77 AuthError::MissingRequiredClaim(claim.to_string())
78 }
79 _ => AuthError::InternalError,
80 },
81 _ => AuthError::InternalError,
82 }
83}
84
85#[derive(Debug, PartialEq, thiserror::Error)]
89pub enum AuthError {
90 #[error("Invalid token")]
92 InvalidToken,
93
94 #[error("Invalid signature")]
96 InvalidSignature,
97
98 #[error("Missing required claim: {0}")]
101 MissingRequiredClaim(String),
102
103 #[error("Expired signature")]
105 ExpiredSignature,
106
107 #[error("Invalid issuer")]
109 InvalidIssuer,
110
111 #[error("Invalid audience")]
113 InvalidAudience,
114
115 #[error("Invalid subject")]
117 InvalidSubject,
118
119 #[error("Immature signature")]
121 ImmatureSignature,
122
123 #[error("Invalid algorithm")]
126 InvalidAlgorithm,
127
128 #[error("Missing algorithm")]
130 MissingAlgorithm,
131
132 #[error("Missing token")]
134 MissingToken,
135
136 #[error("Internal error")]
140 InternalError,
141}
142
143impl IntoResponse for AuthError {
144 fn into_response(self) -> Response {
145 let (status, msg) = match self {
146 AuthError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"),
147 AuthError::InvalidSignature => (StatusCode::UNAUTHORIZED, "Invalid signature"),
148 AuthError::MissingRequiredClaim(_) => {
149 (StatusCode::UNAUTHORIZED, "Missing required claim")
150 }
151 AuthError::ExpiredSignature => (StatusCode::UNAUTHORIZED, "Expired signature"),
152 AuthError::InvalidIssuer => (StatusCode::UNAUTHORIZED, "Invalid issuer"),
153 AuthError::InvalidAudience => (StatusCode::UNAUTHORIZED, "Invalid audience"),
154 AuthError::InvalidSubject => (StatusCode::UNAUTHORIZED, "Invalid subject"),
155 AuthError::ImmatureSignature => (StatusCode::UNAUTHORIZED, "Immature signature"),
156 AuthError::InvalidAlgorithm => (StatusCode::UNAUTHORIZED, "Invalid algorithm"),
157 AuthError::MissingAlgorithm => (StatusCode::UNAUTHORIZED, "Missing algorithm"),
158 AuthError::MissingToken => (StatusCode::UNAUTHORIZED, "Missing token"),
159 AuthError::InternalError => (StatusCode::INTERNAL_SERVER_ERROR, "Internal error"),
160 };
161
162 (status, msg).into_response()
163 }
164}
165
166#[derive(Clone)]
167pub struct JwtDecoderState<T> {
168 pub decoder: Decoder<T>,
169}
170
171impl<T> FromRef<JwtDecoderState<T>> for Decoder<T> {
172 fn from_ref(state: &JwtDecoderState<T>) -> Self {
173 state.decoder.clone()
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::extract::Request;
    use jsonwebtoken::errors::Error as JwtError;

    /// Runs [`BearerTokenExtractor`] against an empty-bodied request.
    /// The `Authorization` header is set only when `header` is `Some`.
    async fn extract_with_header(header: Option<&str>) -> Result<String, AuthError> {
        let mut builder = Request::builder();
        if let Some(value) = header {
            builder = builder.header("Authorization", value);
        }
        let req = builder.body(Body::empty()).unwrap();
        let (mut parts, _body) = req.into_parts();
        BearerTokenExtractor::extract_token(&mut parts).await
    }

    /// Every explicitly mapped jsonwebtoken kind lands on its matching `AuthError` variant.
    #[test]
    fn test_map_jwt_error() {
        let cases = vec![
            (ErrorKind::ExpiredSignature, AuthError::ExpiredSignature),
            (ErrorKind::InvalidSignature, AuthError::InvalidSignature),
            (ErrorKind::InvalidAudience, AuthError::InvalidAudience),
            (ErrorKind::InvalidAlgorithm, AuthError::InvalidAlgorithm),
            (ErrorKind::InvalidToken, AuthError::InvalidToken),
            (ErrorKind::InvalidIssuer, AuthError::InvalidIssuer),
            (ErrorKind::InvalidSubject, AuthError::InvalidSubject),
            (ErrorKind::ImmatureSignature, AuthError::ImmatureSignature),
            (ErrorKind::MissingAlgorithm, AuthError::MissingAlgorithm),
            (
                ErrorKind::MissingRequiredClaim("exp".to_owned()),
                AuthError::MissingRequiredClaim("exp".to_owned()),
            ),
        ];

        for (kind, expected) in cases {
            let actual = map_jwt_error(crate::Error::Jwt(JwtError::from(kind)));
            assert_eq!(actual, expected);
        }
    }

    #[tokio::test]
    async fn test_bearer_token_extractor() {
        assert_eq!(
            extract_with_header(Some("Bearer test_token")).await,
            Ok("test_token".to_owned())
        );
        assert_eq!(
            extract_with_header(Some("Not a bearer token")).await,
            Err(AuthError::MissingToken)
        );
        assert_eq!(extract_with_header(None).await, Err(AuthError::MissingToken));
    }
}