// axum_jwt_auth/axum.rs

1use async_trait::async_trait;
2use axum::extract::FromRef;
3use axum::http::StatusCode;
4use axum::response::Response;
5use axum::RequestPartsExt;
6use axum::{http::request::Parts, response::IntoResponse};
7use axum_extra::headers::authorization::Bearer;
8use axum_extra::headers::Authorization;
9use axum_extra::TypedHeader;
10use jsonwebtoken::errors::ErrorKind;
11use serde::de::DeserializeOwned;
12use serde::Deserialize;
13
14use crate::Decoder;
15
16/// A generic struct for holding the claims of a JWT token.
17#[derive(Debug, Deserialize)]
18pub struct Claims<T>(pub T);
19
20/// Trait for extracting tokens from request parts
21#[async_trait]
22pub trait TokenExtractor {
23    async fn extract_token(parts: &mut Parts) -> Result<String, AuthError>;
24}
25
/// Default [`TokenExtractor`]: reads the token from an `Authorization: Bearer <token>` header.
#[derive(Debug, Clone, Copy, Default)]
pub struct BearerTokenExtractor;
28
29#[async_trait]
30impl TokenExtractor for BearerTokenExtractor {
31    async fn extract_token(parts: &mut Parts) -> Result<String, AuthError> {
32        let auth: TypedHeader<Authorization<Bearer>> =
33            parts.extract().await.map_err(|_| AuthError::MissingToken)?;
34
35        Ok(auth.token().to_string())
36    }
37}
38
39impl<S, T> axum::extract::FromRequestParts<S> for Claims<T>
40where
41    JwtDecoderState<T>: FromRef<S>,
42    S: Send + Sync,
43    T: DeserializeOwned,
44{
45    type Rejection = AuthError;
46
47    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
48        // TODO: Allow for custom token extractors?
49        let token = BearerTokenExtractor::extract_token(parts).await?;
50
51        let state = JwtDecoderState::from_ref(state);
52        let token_data = state
53            .decoder
54            .clone()
55            .decode(&token)
56            .await
57            .map_err(map_jwt_error)?;
58
59        Ok(Claims(token_data.claims))
60    }
61}
62
63/// Maps JWT errors to AuthError
64fn map_jwt_error(err: crate::Error) -> AuthError {
65    match err {
66        crate::Error::Jwt(e) => match e.kind() {
67            ErrorKind::ExpiredSignature => AuthError::ExpiredSignature,
68            ErrorKind::InvalidSignature => AuthError::InvalidSignature,
69            ErrorKind::InvalidAudience => AuthError::InvalidAudience,
70            ErrorKind::InvalidAlgorithm => AuthError::InvalidAlgorithm,
71            ErrorKind::InvalidToken => AuthError::InvalidToken,
72            ErrorKind::InvalidIssuer => AuthError::InvalidIssuer,
73            ErrorKind::InvalidSubject => AuthError::InvalidSubject,
74            ErrorKind::ImmatureSignature => AuthError::ImmatureSignature,
75            ErrorKind::MissingAlgorithm => AuthError::MissingAlgorithm,
76            ErrorKind::MissingRequiredClaim(claim) => {
77                AuthError::MissingRequiredClaim(claim.to_string())
78            }
79            _ => AuthError::InternalError,
80        },
81        _ => AuthError::InternalError,
82    }
83}
84
85/// An enum representing the possible errors that can occur when authenticating a request.
86/// These are sourced from the `jsonwebtoken` crate and defined here to implement `IntoResponse` for
87/// use in the `axum` framework.
88#[derive(Debug, PartialEq, thiserror::Error)]
89pub enum AuthError {
90    /// When the token is invalid
91    #[error("Invalid token")]
92    InvalidToken,
93
94    /// When the signature is invalid
95    #[error("Invalid signature")]
96    InvalidSignature,
97
98    // Validation errors
99    /// When a claim required by the validation is not present
100    #[error("Missing required claim: {0}")]
101    MissingRequiredClaim(String),
102
103    /// When a token's `exp` claim indicates that it has expired
104    #[error("Expired signature")]
105    ExpiredSignature,
106
107    /// When a token's `iss` claim does not match the expected issuer
108    #[error("Invalid issuer")]
109    InvalidIssuer,
110
111    /// When a token's `aud` claim does not match one of the expected audience values
112    #[error("Invalid audience")]
113    InvalidAudience,
114
115    /// When a token's `sub` claim does not match one of the expected subject values
116    #[error("Invalid subject")]
117    InvalidSubject,
118
119    /// When a token's `nbf` claim represents a time in the future
120    #[error("Immature signature")]
121    ImmatureSignature,
122
123    /// When the algorithm in the header doesn't match the one passed to `decode` or the encoding/decoding key
124    /// used doesn't match the alg requested
125    #[error("Invalid algorithm")]
126    InvalidAlgorithm,
127
128    /// When the Validation struct does not contain at least 1 algorithm
129    #[error("Missing algorithm")]
130    MissingAlgorithm,
131
132    /// When the request is missing a token
133    #[error("Missing token")]
134    MissingToken,
135
136    /// When an internal error occurs that doesn't fit into the other categories.
137    /// This is a catch-all error for any unexpected errors that occur such as
138    /// network errors, decoding errors, and cryptographic errors.
139    #[error("Internal error")]
140    InternalError,
141}
142
143impl IntoResponse for AuthError {
144    fn into_response(self) -> Response {
145        let (status, msg) = match self {
146            AuthError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"),
147            AuthError::InvalidSignature => (StatusCode::UNAUTHORIZED, "Invalid signature"),
148            AuthError::MissingRequiredClaim(_) => {
149                (StatusCode::UNAUTHORIZED, "Missing required claim")
150            }
151            AuthError::ExpiredSignature => (StatusCode::UNAUTHORIZED, "Expired signature"),
152            AuthError::InvalidIssuer => (StatusCode::UNAUTHORIZED, "Invalid issuer"),
153            AuthError::InvalidAudience => (StatusCode::UNAUTHORIZED, "Invalid audience"),
154            AuthError::InvalidSubject => (StatusCode::UNAUTHORIZED, "Invalid subject"),
155            AuthError::ImmatureSignature => (StatusCode::UNAUTHORIZED, "Immature signature"),
156            AuthError::InvalidAlgorithm => (StatusCode::UNAUTHORIZED, "Invalid algorithm"),
157            AuthError::MissingAlgorithm => (StatusCode::UNAUTHORIZED, "Missing algorithm"),
158            AuthError::MissingToken => (StatusCode::UNAUTHORIZED, "Missing token"),
159            AuthError::InternalError => (StatusCode::INTERNAL_SERVER_ERROR, "Internal error"),
160        };
161
162        (status, msg).into_response()
163    }
164}
165
166#[derive(Clone)]
167pub struct JwtDecoderState<T> {
168    pub decoder: Decoder<T>,
169}
170
171impl<T> FromRef<JwtDecoderState<T>> for Decoder<T> {
172    fn from_ref(state: &JwtDecoderState<T>) -> Self {
173        state.decoder.clone()
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::extract::Request;
    use jsonwebtoken::errors::Error as JwtError;

    /// Builds request parts, optionally carrying an `Authorization` header with `value`.
    fn parts_with_authorization(value: Option<&str>) -> Parts {
        let builder = match value {
            Some(value) => Request::builder().header("Authorization", value),
            None => Request::builder(),
        };
        builder.body(Body::empty()).unwrap().into_parts().0
    }

    /// Wraps a `jsonwebtoken` error kind in the crate error type.
    fn jwt(kind: ErrorKind) -> crate::Error {
        crate::Error::Jwt(JwtError::from(kind))
    }

    #[test]
    fn test_map_jwt_error() {
        assert_eq!(
            map_jwt_error(jwt(ErrorKind::ExpiredSignature)),
            AuthError::ExpiredSignature
        );
        // The name of the missing claim is carried through.
        assert_eq!(
            map_jwt_error(jwt(ErrorKind::MissingRequiredClaim("exp".to_string()))),
            AuthError::MissingRequiredClaim("exp".to_string())
        );
        // Key problems are server-side failures.
        assert_eq!(
            map_jwt_error(jwt(ErrorKind::InvalidEcdsaKey)),
            AuthError::InternalError
        );
    }

    #[test]
    fn test_auth_error_status_codes() {
        assert_eq!(
            AuthError::MissingToken.into_response().status(),
            StatusCode::UNAUTHORIZED
        );
        assert_eq!(
            AuthError::ExpiredSignature.into_response().status(),
            StatusCode::UNAUTHORIZED
        );
        assert_eq!(
            AuthError::MissingRequiredClaim("exp".to_string())
                .into_response()
                .status(),
            StatusCode::UNAUTHORIZED
        );
        assert_eq!(
            AuthError::InternalError.into_response().status(),
            StatusCode::INTERNAL_SERVER_ERROR
        );
    }

    #[tokio::test]
    async fn test_bearer_token_extractor() {
        // Valid bearer token.
        let mut parts = parts_with_authorization(Some("Bearer test_token"));
        assert_eq!(
            BearerTokenExtractor::extract_token(&mut parts).await,
            Ok("test_token".to_string())
        );

        // Header present but not using the Bearer scheme.
        let mut parts = parts_with_authorization(Some("Not a bearer token"));
        assert_eq!(
            BearerTokenExtractor::extract_token(&mut parts).await,
            Err(AuthError::MissingToken)
        );

        // No Authorization header at all.
        let mut parts = parts_with_authorization(None);
        assert_eq!(
            BearerTokenExtractor::extract_token(&mut parts).await,
            Err(AuthError::MissingToken)
        );
    }
}