axum_jwt_auth/
axum.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
use axum::extract::FromRef;
use axum::http::StatusCode;
use axum::response::Response;
use axum::RequestPartsExt;
use axum::{http::request::Parts, response::IntoResponse};
use axum_extra::headers::authorization::Bearer;
use axum_extra::headers::Authorization;
use axum_extra::TypedHeader;
use serde::de::DeserializeOwned;
use serde::Deserialize;

use crate::{Decoder, JwtDecoder};

/// Extractor wrapping the decoded JWT claims payload of type `T`.
///
/// Taking `Claims<MyClaims>` as a handler argument requires the request to carry a valid
/// `Authorization: Bearer <token>` header. The token is decoded with the [`Decoder`] held in
/// [`JwtDecoderState`], and failures are rejected with an [`AuthError`].
#[derive(Debug, Clone, Deserialize)]
pub struct Claims<T>(pub T);

/// Reasons the [`Claims`] extractor can reject a request.
///
/// Every variant except [`AuthError::InternalError`] is rendered as `401 Unauthorized`;
/// `InternalError` is rendered as `500 Internal Server Error`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthError {
    /// The token was rejected for a reason without a more specific variant.
    InvalidToken,
    /// The request did not carry a bearer token.
    MissingToken,
    /// The token's expiry time has passed.
    ExpiredToken,
    /// The token's signature did not verify.
    InvalidSignature,
    /// The token's audience claim did not match the expected audience.
    InvalidAudience,
    /// Decoding failed for a reason unrelated to token validation.
    InternalError,
}

impl IntoResponse for AuthError {
    fn into_response(self) -> Response {
        let (status, msg) = match self {
            AuthError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid token"),
            AuthError::MissingToken => (StatusCode::UNAUTHORIZED, "Missing token"),
            AuthError::ExpiredToken => (StatusCode::UNAUTHORIZED, "Expired token"),
            AuthError::InvalidSignature => (StatusCode::UNAUTHORIZED, "Invalid signature"),
            AuthError::InvalidAudience => (StatusCode::UNAUTHORIZED, "Invalid audience"),
            AuthError::InternalError => (StatusCode::INTERNAL_SERVER_ERROR, "Internal error"),
        };

        (status, msg).into_response()
    }
}

/// Router state carrying the JWT [`Decoder`] that the [`Claims`] extractor uses.
///
/// The extractor obtains this struct from the application state via `FromRef`. The
/// `FromRef` derive additionally lets the `decoder` field be pulled out as sub-state.
#[derive(FromRef)]
#[derive(Clone)]
pub struct JwtDecoderState {
    /// Decoder used to decode and validate incoming bearer tokens.
    pub decoder: Decoder,
}

impl<S, T> axum::extract::FromRequestParts<S> for Claims<T>
where
    JwtDecoderState: FromRef<S>,
    S: Send + Sync,
    T: DeserializeOwned,
{
    type Rejection = AuthError;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        // `TypedHeader<Authorization<Bearer>>` extracts the auth token
        let auth: TypedHeader<Authorization<Bearer>> = parts
            .extract()
            .await
            .map_err(|_| Self::Rejection::MissingToken)?;

        let state = JwtDecoderState::from_ref(state);
        // `JwtDecoder::decode` decodes the token
        let token_data = state.decoder.decode(auth.token()).map_err(|e| match e {
            crate::Error::Jwt(e) => match e.kind() {
                jsonwebtoken::errors::ErrorKind::ExpiredSignature => Self::Rejection::ExpiredToken,
                jsonwebtoken::errors::ErrorKind::InvalidSignature => {
                    Self::Rejection::InvalidSignature
                }
                jsonwebtoken::errors::ErrorKind::InvalidAudience => {
                    Self::Rejection::InvalidAudience
                }
                _ => Self::Rejection::InvalidToken,
            },
            _ => Self::Rejection::InternalError,
        })?;

        Ok(token_data.claims)
    }
}